Commit
Merge pull request #15 from jjpn97/openai_migrate
Openai migrate
aypan17 authored Nov 2, 2024
2 parents 5796a59 + d4f17c5 commit 3c42d13
Showing 3 changed files with 54 additions and 44 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -3,8 +3,10 @@ experiments/trajectories
__pycache__/
*.pyc
checkpoints*
models
model/*
machiavelli/game/agent/logs/*
*/logs/*
demo*
wandb*
env.sh
.env
31 changes: 19 additions & 12 deletions machiavelli/agent/lm_agents.py
@@ -118,7 +118,8 @@ def get_action_probabilities(self, obs: str, rew: float, done: bool, info: Mappi
logit_bias=logit_whitelist(valid_labels),
logprobs=num_actions
)
text = res['choices'][0]['text']
text = res.choices[0].message.content

print(text)
probs = get_class_probs(res, valid_labels)
print(probs)
@@ -144,8 +145,8 @@ def get_action_completion(self, obs: str, rew: float, done: bool, info: Mapping[
print(sys_prompt)
print(prompt)

def validation_fn(res):
text = res['choices'][0]['message']['content']
def validation_fn(res):
text = res.choices[0].message.content
print(f"Validating ({text}) ({valid_labels})")
try:
choice = re.findall(r'\d+', text.strip())[0] # Grab the first number from the string
@@ -167,7 +168,8 @@ def validation_fn(res):
self.random_actions += 1
return np.random.choice(num_actions)

text = res['choices'][0]['message']['content']
text = res.choices[0].message.content

choice = re.findall(r'\d+', text.strip())[0] # Grab the first number from the string
assert choice in valid_labels
return int(choice)
@@ -208,7 +210,8 @@ def get_action_probabilities(self, obs: str, rew: float, done: bool, info: Mappi
logit_bias=logit_whitelist(valid_labels),
logprobs=num_actions
)
text = res['choices'][0]['text']
text = res.choices[0].message

print(text)
probs = get_class_probs(res, valid_labels)
print(probs)
@@ -235,7 +238,8 @@ def get_action_completion(self, obs: str, rew: float, done: bool, info: Mapping[
print(prompt)

def validation_fn(res):
text = res['choices'][0]['message']['content']
# text = res['choices'][0]['message']['content']
text = res.choices[0].message.content
print(f"Validating ({text}) ({valid_labels})")
try:
choice = re.findall(r'\d+', text.strip())[0] # Grab the first number from the string
@@ -257,7 +261,8 @@ def validation_fn(res):
self.random_actions += 1
return np.random.choice(num_actions)

text = res['choices'][0]['message']['content']
# text = res['choices'][0]['message']['content']
text = res.choices[0].message.content
print(text)
choice = re.findall(r'\d+', text.strip())[0] # Grab the first number from the string
assert choice in valid_labels
@@ -300,7 +305,8 @@ def get_action_probabilities(self, obs: str, rew: float, done: bool, info: Mappi
max_tokens=1048,
temperature=self.temperature,
)
text = res['choices'][0]['text']

text = res.choices[0].message
print(text)

# Second inference call is just to extract the final answer from the above reasoning
@@ -316,8 +322,8 @@ def get_action_probabilities(self, obs: str, rew: float, done: bool, info: Mappi
logit_bias=logit_whitelist(valid_labels),
logprobs=num_actions
)
# print(res)
text = res['choices'][0]['text']

text = res.choices[0].message
print(text)
probs = get_class_probs(res, valid_labels)
print(probs)
@@ -362,7 +368,7 @@ def get_action_completion(self, obs: str, rew: float, done: bool, info: Mapping[
print(prompt)

def validation_fn(res):
text = res['choices'][0]['message']['content']
text = res.choices[0].message.content
print(f"Validating ({text}) ({valid_labels})")
try:
obj = json.loads(text.strip())
@@ -384,7 +390,8 @@ def validation_fn(res):
print("Unable to elicit valid response, choosing randomly")
self.random_actions += 1
return np.random.choice(num_actions)
text = res['choices'][0]['message']['content']

text = res.choices[0].message.content
print(text)
choice = str(json.loads(text.strip())['action'])
assert choice in valid_labels
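For orientation, here is a minimal, self-contained sketch (not part of this diff) of how the two endpoints expose their output under the openai>=1.0 client that this PR migrates to. The model names are placeholders; the point is that Chat Completions responses carry the generated text on choices[0].message.content, while legacy Completions responses carry it on choices[0].text together with logprobs.top_logprobs.

# Sketch only: assumes openai>=1.0 and placeholder model names.
import os

from openai import OpenAI

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Chat Completions: generated text is on choices[0].message.content
chat_res = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder chat model
    messages=[{"role": "user", "content": "Reply with a single number."}],
)
print(chat_res.choices[0].message.content)

# Legacy Completions: generated text is on choices[0].text, and
# logprobs=N returns logprobs.top_logprobs for get_class_probs-style use
comp_res = client.completions.create(
    model="gpt-3.5-turbo-instruct",  # placeholder completions model
    prompt="1 + 1 =",
    max_tokens=1,
    logprobs=5,
)
print(comp_res.choices[0].text)
print(comp_res.choices[0].logprobs.top_logprobs[0])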
63 changes: 32 additions & 31 deletions machiavelli/openai_helpers.py
@@ -1,13 +1,15 @@
import copy
import os
import time
from typing import Dict, List, Mapping, Optional
from typing import Dict, List, Mapping

import numpy as np
import openai
from openai import OpenAI
import numpy as np
import wandb
from transformers import GPT2TokenizerFast

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

@@ -16,8 +18,8 @@
"""

def completion_with_retries(model: str, prompt: str, max_retries: int = 10, retry_interval_sec: int = 20, validation_fn=None, **kwargs) -> Mapping:
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.organization = os.getenv("OPENAI_API_ORG")
# TODO: The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=os.getenv("OPENAI_API_ORG"))'
# openai.organization = os.getenv("OPENAI_API_ORG")
print("Using OpenAI API with key:", openai.api_key, " | org:", openai.organization)

if model == 'davinci': # smaller context length; TODO: Handle this more neatly
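
The TODO in this hunk notes that openai.organization is not read by the new client. A minimal sketch of the suggested alternative (an assumption about the openai>=1.0 constructor, not code from this commit) would pass it at instantiation time:

import os

from openai import OpenAI

# Sketch: supply the organization when constructing the client, as the
# TODO suggests; assumes OPENAI_API_ORG is set in the environment.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    organization=os.getenv("OPENAI_API_ORG"),
)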
@@ -33,23 +35,23 @@ def completion_with_retries(model: str, prompt: str, max_retries: int = 10, retr
if wandb.run is not None:
wandb.log({"openai_api": 0.5}, step=int(time.time() * 1000))
try:
res = openai.Completion.create(model=model, prompt=prompt, **kwargs)
res = client.completions.create(model=model, prompt=prompt, **kwargs)
if wandb.run is not None:
wandb.log({"openai_api": 1}, step=int(time.time() * 1000))

if validation_fn is not None:
if not validation_fn(res):
print(f"VALIDATION FAILED!\nResponse:\n{res}\nTrying again.")
continue
return res
except (
openai.error.RateLimitError,
openai.error.ServiceUnavailableError,
openai.error.APIError,
openai.error.APIConnectionError,
openai.error.Timeout,
openai.error.TryAgain,
openai.error.OpenAIError,
openai.RateLimitError,
openai.ServiceUnavailableError,
openai.APIError,
openai.APIConnectionError,
openai.Timeout,
openai.TryAgain,
openai.OpenAIError,
) as e:
if wandb.run is not None:
wandb.log({"openai_api": 0}, step=int(time.time() * 1000))
@@ -59,8 +61,8 @@ def completion_with_retries(model: str, prompt: str, max_retries: int = 10, retr
return {}

def chat_completion_with_retries(model: str, sys_prompt: str, prompt: str, max_retries: int = 5, retry_interval_sec: int = 20, api_keys=None, validation_fn=None, **kwargs) -> Mapping:
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.organization = os.getenv("OPENAI_API_ORG")
# TODO: The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=os.getenv("OPENAI_API_ORG"))'
# openai.organization = os.getenv("OPENAI_API_ORG")
print("Using OpenAI API with key:", openai.api_key, " | org:", openai.organization)

# Truncate if necessary
@@ -91,29 +93,28 @@ def chat_completion_with_retries(model: str, sys_prompt: str, prompt: str, max_r
if wandb.run is not None:
wandb.log({"openai_api": 0.5}, step=int(time.time() * 1000))
try:
res = openai.ChatCompletion.create(
model=model,
messages=[
{"role": "system", "content": sys_prompt},
{"role": "user", "content": prompt},
],
**kwargs)
res = client.chat.completions.create(model=model,
messages=[
{"role": "system", "content": sys_prompt},
{"role": "user", "content": prompt},
],
**kwargs)
if wandb.run is not None:
wandb.log({"openai_api": 1}, step=int(time.time() * 1000))

if validation_fn is not None:
if not validation_fn(res):
print(f"VALIDATION FAILED!\nResponse:\n{res}\nTrying again.")
continue
return res
except (
openai.error.RateLimitError,
openai.error.ServiceUnavailableError,
openai.error.APIError,
openai.error.APIConnectionError,
openai.error.Timeout,
openai.error.TryAgain,
openai.error.OpenAIError,
openai.RateLimitError,
openai.ServiceUnavailableError,
openai.APIError,
openai.APIConnectionError,
openai.Timeout,
openai.TryAgain,
openai.OpenAIError,
) as e:
if wandb.run is not None:
wandb.log({"openai_api": 0}, step=int(time.time() * 1000))
@@ -132,7 +133,7 @@ def logit_whitelist(labels: List[str]):

def get_class_probs(res: Dict, class_labels: List[str]):
# Extract logprobs for the output classes
choice, top_logprobs = res['choices'][0]['text'], res['choices'][0]['logprobs']['top_logprobs'][0]
choice, top_logprobs = res.choices[0].text, res.choices[0].logprobs.top_logprobs[0]
# Get logprobs for the output classes
# OAI returns max 5 top_logprobs, so sometimes the label is not there and we need to fill it with negligible value
logprobs = [top_logprobs[label] if label in top_logprobs else -100 for label in class_labels]
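For comparison, a condensed sketch of the same retry pattern written directly against the openai>=1.0 client, using only exception classes the 1.x SDK exports at the package top level (RateLimitError, APIConnectionError, APITimeoutError, APIError); the wrapper name and defaults here are illustrative, not from the commit.

import time

import openai
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

def chat_completion_with_retries_sketch(model, sys_prompt, prompt,
                                        max_retries=5, retry_interval_sec=20,
                                        **kwargs):
    # Retry loop sketch: return the response on success, {} after exhausting retries.
    for attempts_remaining in range(max_retries, 0, -1):
        try:
            return client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": sys_prompt},
                    {"role": "user", "content": prompt},
                ],
                **kwargs,
            )
        except (openai.RateLimitError, openai.APIConnectionError,
                openai.APITimeoutError, openai.APIError) as e:
            print(f"Hit OpenAI API error ({e}); retrying in {retry_interval_sec}s "
                  f"({attempts_remaining - 1} attempts remaining)")
            time.sleep(retry_interval_sec)
    return {}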
