-
Notifications
You must be signed in to change notification settings - Fork 0
/
gpt_data_generate.py
72 lines (57 loc) · 1.93 KB
/
gpt_data_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import random
import pandas as pd
from openai import AzureOpenAI
from prompts import SEED_EXAMPLES, generation_prompt
# Azure OpenAI configuration -- REPLACE THE PLACEHOLDERS WITH YOUR OWN SETUP.
# setdefault() only installs the placeholder when the variable is not already
# present, so credentials exported in the real environment are kept instead of
# being overwritten with an empty string.
os.environ.setdefault("AZURE_OPENAI_KEY", "")
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "")

API_BASE = os.getenv("AZURE_OPENAI_ENDPOINT")
API_KEY = os.getenv("AZURE_OPENAI_KEY")

# NOTE(review): 'got-4o' looks like a typo of 'gpt-4o' -- confirm against the
# deployment name configured in your Azure resource before changing it.
DEPLOYMENT_NAME = 'got-4o'
API_VERSION = '2024-05-01-preview'

# Shared client used by process_row(); the base URL is built from the endpoint
# and the deployment name above.
client = AzureOpenAI(
    api_key=API_KEY,
    api_version=API_VERSION,
    base_url=f"{API_BASE}/openai/deployments/{DEPLOYMENT_NAME}"
)
def _extract_disease(raw_output):
    """Return the first line of ``raw_output``, starting at its first '['.

    If the first line contains no '[', the whole first line is returned;
    an empty or whitespace-only ``raw_output`` yields ``""``.
    """
    lines = raw_output.strip().splitlines()
    if not lines:
        return ""
    first_line = lines[0]
    start = first_line.find('[')
    # str.find() returns -1 when '[' is absent; slicing with -1 would keep
    # only the last character of the line, so fall back to the full line.
    return first_line[start:] if start != -1 else first_line


# Function to process a single row and return the results
def process_row(note):
    """Generate a disease annotation for one note with the chat model.

    Three seed examples are drawn at random from ``SEED_EXAMPLES`` and passed,
    together with ``note``, to ``generation_prompt`` to build the user message.

    Args:
        note: The note value handed to ``generation_prompt``.

    Returns:
        A dict with the keys ``"note"``, ``"generated_disease"`` and
        ``"raw_output"``, or ``None`` when the API call fails.
    """
    examples = random.sample(SEED_EXAMPLES, 3)
    prompt = generation_prompt(examples[0], examples[1], examples[2], note)

    try:
        response = client.chat.completions.create(
            model=DEPLOYMENT_NAME,
            messages=[
                {"role": "system", "content": "You are a helpful radiologist."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=150,
            temperature=0,
        )
    except Exception as e:
        # Best-effort boundary around the remote call: report the failure and
        # signal it with None rather than falling through to an unbound
        # `response`. The row index is not known here, so it is not printed.
        print(f"Error generating disease for note: {e}")
        return None

    # The message content may be None (it is optional in the SDK's response
    # model); treat that the same as an empty completion.
    content = response.choices[0].message.content
    raw_output = (content or "").strip()

    return {
        "note": note,
        "generated_disease": _extract_disease(raw_output),
        "raw_output": raw_output,
    }
# Driver: annotate every note in ./notes.csv and write the collected results
# to gpt_annotated.csv.
df = pd.read_csv("./notes.csv")
results = []

for index, row in df.iterrows():
    print(f"Processing row {index}...")
    try:
        result = process_row(row['note'])
    except Exception as e:
        # A failing row is reported and skipped so the rest of the file is
        # still processed.
        print(f"Error processing row {index}: {e}")
        continue
    if result is None:
        continue
    results.append(result)

# One output row per successfully processed note.
df_results = pd.DataFrame(results)
df_results.to_csv("gpt_annotated.csv", index=False)
print("Processing complete. Result saved.")