-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsample_prompts_by_action.py
138 lines (123 loc) · 5.98 KB
/
sample_prompts_by_action.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""Sample a prompt for a specific action."""
import numpy as np
from logic.prompts import get_user_part_of_prompt
# Maps each supported action name to the prompt-template file(s) its example
# questions are sampled from. A plain string is a single template file; a list
# means several template variants exist (typically a hand-written one and a
# ChatGPT-generated one) and sample_prompt_for_action picks one of them at
# random.
ACTION_2_FILENAME = {
    # About
    "self": "describe_self.txt",
    "function": "describe_function.txt",
    # Metadata
    "show": "show.txt",
    "model": "describe_model.txt",
    "data_summary": "describe_data.txt",
    "countdata": "count_data.txt",
    "labels": "label.txt",
    # Prediction
    "predict": "predict.txt",
    "random_predict": ["random_predict.txt", "random_predict_chatgpt.txt"],
    "score": "score.txt",
    "mistake_count": "mistake_count.txt",
    "mistake_sample": "mistake_sample.txt",
    # Understanding
    "similar": ["similar.txt", "similar_chatgpt.txt"],
    "keyword": "keywords.txt",
    # Explanation
    "explain": ["local_feature_importance.txt", "local_feature_importance_chatgpt.txt"],
    "rationalization": ["rationalize.txt", "rationalize_chatgpt.txt"],
    # Custom input
    "custom_input_prediction": ["custom_input_prediction_chatgpt.txt", "custom_input_prediction.txt"],
    "custom_input_feature_importance": ["custom_input_feature_importance_chatgpt.txt", "custom_input_feature_importance.txt"],
    "c_rationale": ["custom_input_rationalize.txt", "custom_input_rationalize_chatgpt.txt"],
    "c_cfe": ["custom_input_cfe.txt", "custom_input_cfe_chatgpt.txt"],
    "c_data": ["custom_input_data_augmentation.txt", "custom_input_data_augmentation_chatgpt.txt"],
    "c_sim": "custom_input_similarity.txt",
    # Perturbation
    "cfe": "cfe.txt",
    "augment": ["augmentation_chatgpt.txt", "augmentation.txt"],
    # QA
    "qa_cfe": "qacfe.txt",
    "qa_da": "qada.txt",
    "qa_fa": "qafeature_attribution.txt",
    "qa_rat": "qarationale.txt",
    "qa_sim": "qasim.txt",
    "qa_ci": "qacustominput.txt"
}
def replace_non_existent_id_with_real_id(prompt: str, real_ids: list[int]) -> str:
    """Attempts to replace ids that don't exist in the data with ones that do.

    The problem is that prompt generation may create prompts with ids that don't
    occur in the data. This is fine for generating training data, but when the
    user clicks "generate an example question" the example should use ids that
    actually occur in the data, so running the prompt doesn't error. This
    function finds numbers that follow id-introducing words and replaces them
    with ids known to occur in the data.

    Arguments:
        prompt: The prompt to try and replace with an actual id in the data.
        real_ids: A list of ids that occur in the data **in actuality**.
    Returns:
        resolved_prompt: The prompt with ids replaced.
    """
    # Nothing to substitute with; np.random.choice would raise on an empty list.
    if not real_ids:
        return prompt
    tokens = prompt.split()
    # Words that typically introduce a data-point id in the generated prompts.
    id_markers = ("point", "id", "number", "for", "instance")
    for i, word in enumerate(tokens[:-1]):
        if word in id_markers:
            nxt = tokens[i + 1]
            if nxt.isnumeric():
                tokens[i + 1] = str(np.random.choice(real_ids))
            elif nxt[:-1].isnumeric():
                # Number followed by sentence-end punctuation: keep the
                # punctuation (previously it was dropped, e.g. "id 3." -> "id 5").
                tokens[i + 1] = str(np.random.choice(real_ids)) + nxt[-1]
    return " ".join(tokens)
def sample_prompt_for_action(action: str,
                             filename_to_prompt_ids: dict,
                             prompt_set: dict,
                             conversation) -> str:
    """Samples a prompt for a specific action.

    Arguments:
        action: The action to sample a prompt for.
        filename_to_prompt_ids: A map from the prompt filenames to the ids of
                                prompts. Note, in this context, 'id' refers to
                                the 'id' assigned to each prompt and *not* an
                                id of a data point.
        prompt_set: The full prompt set. This is a dictionary that maps from
                    **prompt** ids to info about the prompt.
        conversation: The conversation object; provides the real data-point ids
                      and the `used` flag.
    Returns:
        prompt: The sampled prompt.
    Raises:
        NameError: If the action is unknown, or no prompt file matches the
                   action's template filename.
    """
    real_ids = conversation.get_training_data_ids()
    if action == "self":
        return "Could you tell me a bit more about what this is?"
    if action == "function":
        return "What can you do?"
    if action not in ACTION_2_FILENAME:
        raise NameError(f"Unknown action {action}")

    filename_end = ACTION_2_FILENAME[action]
    # Some actions map to several template variants; pick one at random.
    if not isinstance(filename_end, str):
        filename_end = filename_end[np.random.choice(len(filename_end))]

    for filename, prompt_ids in filename_to_prompt_ids.items():
        if ("includes" not in filename) and filename.endswith(filename_end):
            # Skip custom-input prompts once the conversation is in use.
            if conversation.used and "custom_input" in filename:
                continue
            chosen_id = np.random.choice(prompt_ids)
            prompt = prompt_set[chosen_id]["prompts"][0]
            user_part = get_user_part_of_prompt(prompt)
            # Resample (up to 100 tries) prompts that still contain an unfilled
            # wildcard "{" for this dataset. BUGFIX: the condition previously
            # used `or`, which forced >=100 iterations every call and looped
            # forever whenever all candidate prompts contained a wildcard;
            # `and` implements the intended bounded retry.
            tries = 0
            while "{" in user_part and tries < 100:
                chosen_id = np.random.choice(prompt_ids)
                tries += 1
                prompt = prompt_set[chosen_id]["prompts"][0]
                user_part = get_user_part_of_prompt(prompt)
            return replace_non_existent_id_with_real_id(user_part, real_ids)
        if ("includes" in filename) and ("includes" in filename_end) and filename.endswith(filename_end):
            chosen_id = np.random.choice(prompt_ids)
            prompt = prompt_set[chosen_id]["prompts"][0]
            return get_user_part_of_prompt(prompt)

    # BUGFIX: error message previously read "Unable to filename ending in".
    raise NameError(f"Unable to find filename ending in {filename_end}!")