diff --git a/.gitignore b/.gitignore index 3cf1796..7e62c03 100644 --- a/.gitignore +++ b/.gitignore @@ -3,8 +3,10 @@ experiments/trajectories __pycache__/ *.pyc checkpoints* -models +model/* machiavelli/game/agent/logs/* */logs/* demo* wandb* +env.sh +.env \ No newline at end of file diff --git a/machiavelli/agent/lm_agents.py b/machiavelli/agent/lm_agents.py index 8dc6a48..7297b36 100644 --- a/machiavelli/agent/lm_agents.py +++ b/machiavelli/agent/lm_agents.py @@ -118,7 +118,8 @@ def get_action_probabilities(self, obs: str, rew: float, done: bool, info: Mappi logit_bias=logit_whitelist(valid_labels), logprobs=num_actions ) - text = res['choices'][0]['text'] + text = res.choices[0].message.content + print(text) probs = get_class_probs(res, valid_labels) print(probs) @@ -144,8 +145,8 @@ def get_action_completion(self, obs: str, rew: float, done: bool, info: Mapping[ print(sys_prompt) print(prompt) - def validation_fn(res): - text = res['choices'][0]['message']['content'] + def validation_fn(res): + text = res.choices[0].message.content print(f"Validating ({text}) ({valid_labels})") try: choice = re.findall(r'\d+', text.strip())[0] # Grab the first number from the string @@ -167,7 +168,8 @@ def validation_fn(res): self.random_actions += 1 return np.random.choice(num_actions) - text = res['choices'][0]['message']['content'] + text = res.choices[0].message.content + choice = re.findall(r'\d+', text.strip())[0] # Grab the first number from the string assert choice in valid_labels return int(choice) @@ -208,7 +210,8 @@ def get_action_probabilities(self, obs: str, rew: float, done: bool, info: Mappi logit_bias=logit_whitelist(valid_labels), logprobs=num_actions ) - text = res['choices'][0]['text'] + text = res.choices[0].message + print(text) probs = get_class_probs(res, valid_labels) print(probs) @@ -235,7 +238,8 @@ def get_action_completion(self, obs: str, rew: float, done: bool, info: Mapping[ print(prompt) def validation_fn(res): - text = res['choices'][0]['message']['content'] + # text = res['choices'][0]['message']['content'] + text = res.choices[0].message.content print(f"Validating ({text}) ({valid_labels})") try: choice = re.findall(r'\d+', text.strip())[0] # Grab the first number from the string @@ -257,7 +261,8 @@ def validation_fn(res): self.random_actions += 1 return np.random.choice(num_actions) - text = res['choices'][0]['message']['content'] + # text = res['choices'][0]['message']['content'] + text = res.choices[0].message.content print(text) choice = re.findall(r'\d+', text.strip())[0] # Grab the first number from the string assert choice in valid_labels @@ -300,7 +305,8 @@ def get_action_probabilities(self, obs: str, rew: float, done: bool, info: Mappi max_tokens=1048, temperature=self.temperature, ) - text = res['choices'][0]['text'] + + text = res.choices[0].message print(text) # Second inference call is just to extract the final answer from the above reasoning @@ -316,8 +322,8 @@ def get_action_probabilities(self, obs: str, rew: float, done: bool, info: Mappi logit_bias=logit_whitelist(valid_labels), logprobs=num_actions ) - # print(res) - text = res['choices'][0]['text'] + + text = res.choices[0].message print(text) probs = get_class_probs(res, valid_labels) print(probs) @@ -362,7 +368,7 @@ def get_action_completion(self, obs: str, rew: float, done: bool, info: Mapping[ print(prompt) def validation_fn(res): - text = res['choices'][0]['message']['content'] + text = res.choices[0].message.content print(f"Validating ({text}) ({valid_labels})") try: obj = json.loads(text.strip()) @@ -384,7 +390,8 @@ def validation_fn(res): print("Unable to elicit valid response, choosing randomly") self.random_actions += 1 return np.random.choice(num_actions) - text = res['choices'][0]['message']['content'] + + text = res.choices[0].message.content print(text) choice = str(json.loads(text.strip())['action']) assert choice in valid_labels diff --git a/machiavelli/openai_helpers.py b/machiavelli/openai_helpers.py index b7e3610..dc734b9 100644 --- a/machiavelli/openai_helpers.py +++ b/machiavelli/openai_helpers.py @@ -1,13 +1,15 @@ import copy import os import time -from typing import Dict, List, Mapping, Optional +from typing import Dict, List, Mapping -import numpy as np import openai +from openai import OpenAI +import numpy as np import wandb from transformers import GPT2TokenizerFast +client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") @@ -16,8 +18,8 @@ """ def completion_with_retries(model: str, prompt: str, max_retries: int = 10, retry_interval_sec: int = 20, validation_fn=None, **kwargs) -> Mapping: - openai.api_key = os.getenv("OPENAI_API_KEY") - openai.organization = os.getenv("OPENAI_API_ORG") + # TODO: The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=os.getenv("OPENAI_API_ORG"))' + # openai.organization = os.getenv("OPENAI_API_ORG") print("Using OpenAI API with key:", openai.api_key, " | org:", openai.organization) if model == 'davinci': # smaller context length; TODO: Handle this more neatly @@ -33,23 +35,23 @@ def completion_with_retries(model: str, prompt: str, max_retries: int = 10, retr if wandb.run is not None: wandb.log({"openai_api": 0.5}, step=int(time.time() * 1000)) try: - res = openai.Completion.create(model=model, prompt=prompt, **kwargs) + res = client.completions.create(model=model, prompt=prompt, **kwargs) if wandb.run is not None: wandb.log({"openai_api": 1}, step=int(time.time() * 1000)) - + if validation_fn is not None: if not validation_fn(res): print(f"VALIDATION FAILED!\nResponse:\n{res}\nTrying again.") continue return res except ( - openai.error.RateLimitError, - openai.error.ServiceUnavailableError, - openai.error.APIError, - openai.error.APIConnectionError, - openai.error.Timeout, - openai.error.TryAgain, - openai.error.OpenAIError, + openai.RateLimitError, + openai.ServiceUnavailableError, + openai.APIError, + openai.APIConnectionError, + openai.Timeout, + openai.TryAgain, + openai.OpenAIError, ) as e: if wandb.run is not None: wandb.log({"openai_api": 0}, step=int(time.time() * 1000)) @@ -59,8 +61,8 @@ def completion_with_retries(model: str, prompt: str, max_retries: int = 10, retr return {} def chat_completion_with_retries(model: str, sys_prompt: str, prompt: str, max_retries: int = 5, retry_interval_sec: int = 20, api_keys=None, validation_fn=None, **kwargs) -> Mapping: - openai.api_key = os.getenv("OPENAI_API_KEY") - openai.organization = os.getenv("OPENAI_API_ORG") + # TODO: The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=os.getenv("OPENAI_API_ORG"))' + # openai.organization = os.getenv("OPENAI_API_ORG") print("Using OpenAI API with key:", openai.api_key, " | org:", openai.organization) # Truncate if necessary @@ -91,29 +93,28 @@ def chat_completion_with_retries(model: str, sys_prompt: str, prompt: str, max_r if wandb.run is not None: wandb.log({"openai_api": 0.5}, step=int(time.time() * 1000)) try: - res = openai.ChatCompletion.create( - model=model, - messages=[ - {"role": "system", "content": sys_prompt}, - {"role": "user", "content": prompt}, - ], - **kwargs) + res = client.chat.completions.create(model=model, + messages=[ + {"role": "system", "content": sys_prompt}, + {"role": "user", "content": prompt}, + ], + **kwargs) if wandb.run is not None: wandb.log({"openai_api": 1}, step=int(time.time() * 1000)) - + if validation_fn is not None: if not validation_fn(res): print(f"VALIDATION FAILED!\nResponse:\n{res}\nTrying again.") continue return res except ( - openai.error.RateLimitError, - openai.error.ServiceUnavailableError, - openai.error.APIError, - openai.error.APIConnectionError, - openai.error.Timeout, - openai.error.TryAgain, - openai.error.OpenAIError, + openai.RateLimitError, + openai.ServiceUnavailableError, + openai.APIError, + openai.APIConnectionError, + openai.Timeout, + openai.TryAgain, + openai.OpenAIError, ) as e: if wandb.run is not None: wandb.log({"openai_api": 0}, step=int(time.time() * 1000)) @@ -132,7 +133,7 @@ def logit_whitelist(labels: List[str]): def get_class_probs(res: Dict, class_labels: List[str]): # Extract logprobs for the output classes - choice, top_logprobs = res['choices'][0]['text'], res['choices'][0]['logprobs']['top_logprobs'][0] + choice, top_logprobs = res.choices[0].text, res.choices[0].logprobs.top_logprobs[0] # Get logprobs for the output classes # OAI returns max 5 top_logprobs, so sometimes the label is not there and we need to fill it with negligible value logprobs = [top_logprobs[label] if label in top_logprobs else -100 for label in class_labels]