Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added gemini-pro to the list of APIs #1

Merged
merged 1 commit into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions conversers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import common
from language_models import GPT, PaLM, HuggingFace, APIModelLlama7B, APIModelVicuna13B
from language_models import GPT, PaLM, HuggingFace, APIModelLlama7B, APIModelVicuna13B, GeminiPro
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import VICUNA_PATH, LLAMA_PATH, ATTACK_TEMP, TARGET_TEMP, ATTACK_TOP_P, TARGET_TOP_P, MAX_PARALLEL_STREAMS
Expand Down Expand Up @@ -216,6 +216,8 @@ def load_indiv_model(model_name):
lm = GPT(model_name)
elif model_name == "palm-2":
lm = PaLM(model_name)
elif model_name == "gemini-pro":
lm = GeminiPro(model_name)
elif model_name == 'llama-2-api-model':
lm = APIModelLlama7B(model_name)
elif model_name == 'vicuna-api-model':
Expand Down Expand Up @@ -282,11 +284,15 @@ def get_model_path_and_template(model_name):
"palm-2":{
"path":"palm-2",
"template":"palm-2"
},
"gemini-pro": {
"path": "gemini-pro",
"template": "gemini-pro"
}
}
path, template = full_model_dict[model_name]["path"], full_model_dict[model_name]["template"]
return path, template





69 changes: 66 additions & 3 deletions language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import gc
from typing import Dict, List
import google.generativeai as palm
import google.generativeai as genai
import urllib3
from copy import deepcopy

Expand Down Expand Up @@ -257,7 +257,7 @@ class PaLM():

def __init__(self, model_name) -> None:
self.model_name = model_name
palm.configure(api_key=self.API_KEY)
genai.configure(api_key=self.API_KEY)

def generate(self, conv: List,
max_n_tokens: int,
Expand All @@ -275,7 +275,7 @@ def generate(self, conv: List,
output = self.API_ERROR_OUTPUT
for _ in range(self.API_MAX_RETRY):
try:
completion = palm.chat(
completion = genai.chat(
messages=conv,
temperature=temperature,
top_p=top_p
Expand Down Expand Up @@ -303,3 +303,66 @@ def batched_generate(self,
temperature: float,
top_p: float = 1.0,):
return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]


class GeminiPro():
API_RETRY_SLEEP = 10
API_ERROR_OUTPUT = "$ERROR$"
API_QUERY_SLEEP = 1
API_MAX_RETRY = 5
API_TIMEOUT = 20
default_output = "I'm sorry, but I cannot assist with that request."
API_KEY = os.getenv("PALM_API_KEY")

def __init__(self, model_name) -> None:
self.model_name = model_name
genai.configure(api_key=self.API_KEY)

def generate(self, conv: List,
max_n_tokens: int,
temperature: float,
top_p: float):
'''
Args:
conv: List of dictionaries,
max_n_tokens: int, max number of tokens to generate
temperature: float, temperature for sampling
top_p: float, top p for sampling
Returns:
str: generated response
'''
output = self.API_ERROR_OUTPUT
for _ in range(self.API_MAX_RETRY):
try:
model = genai.GenerativeModel(self.model_name)
output = model.generate_content(
contents = conv,
generation_config = genai.GenerationConfig(
candidate_count = 1,
temperature = temperature,
top_p = top_p,
max_output_tokens=max_n_tokens,
)
)

if output is None:
# If PaLM refuses to output and returns None, we replace it with a default output
output = self.default_output
else:
# Use this approximation since PaLM does not allow
# to specify max_tokens. Each token is approximately 4 characters.
output = output.text
break
except Exception as e:
print(type(e), e)
time.sleep(self.API_RETRY_SLEEP)

time.sleep(self.API_QUERY_SLEEP)
return output

def batched_generate(self,
convs_list: List[List[Dict]],
max_n_tokens: int,
temperature: float,
top_p: float = 1.0,):
return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]
4 changes: 3 additions & 1 deletion main_TAP.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ def main(args):
"gpt-4",
'gpt-4-turbo',
'gpt-4-1106-preview', # This is same as gpt-4-turbo
"palm-2"]
"palm-2",
"gemini-pro",
]
)
parser.add_argument(
"--target-max-n-tokens",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ pytz==2023.3.post1
torch==2.1.0
transformers==4.35.0
urllib3==2.1.0
google-generativeai==0.3.1