
Commit

Update rank_gpt.py
sunnweiwei authored Jan 20, 2024
1 parent bc2a56f commit 582ea30
Showing 1 changed file with 71 additions and 77 deletions.
rank_gpt.py: 148 changes (71 additions & 77 deletions)
@@ -1,21 +1,13 @@
import copy
from tqdm import tqdm
import time
-import openai
+from openai import OpenAI
import json
import tiktoken

-try:
-    import litellm
-    from litellm import completion
-except:
-    litellm = None
-    completion = openai.ChatCompletion.create


-class SafeOpenai:
+class OpenaiClient:
    def __init__(self, keys=None, start_id=None, proxy=None):
-        from openai import OpenAI
-        import openai
        if isinstance(keys, str):
            keys = [keys]
        if keys is None:
@@ -30,19 +22,13 @@ def __init__(self, keys=None, start_id=None, proxy=None):
    def chat(self, *args, return_text=False, reduce_length=False, **kwargs):
        while True:
            try:
-                model = args[0] if len(args) > 0 else kwargs["model"]
-                if "gpt" in model:
-                    completion = self.client.chat.completions.create(*args, **kwargs, timeout=30)
-                elif model in litellm.model_list:
-                    completion = completion(*args, **kwargs, api_key=self.api_key, force_timeout=30)
+                completion = self.client.chat.completions.create(*args, **kwargs, timeout=30)
                break
            except Exception as e:
                print(str(e))
                if "This model's maximum context length is" in str(e):
                    print('reduce_length')
                    return 'ERROR::reduce_length'
-                self.key_id = (self.key_id + 1) % len(self.key)
-                self.client = OpenAI(api_key=self.api_key)
                time.sleep(0.1)
        if return_text:
            completion = completion.choices[0].message.content
@@ -60,52 +46,57 @@ def text(self, *args, return_text=False, reduce_length=False, **kwargs):
if "This model's maximum context length is" in str(e):
print('reduce_length')
return 'ERROR::reduce_length'
self.key_id = (self.key_id + 1) % len(self.key)
self.client = OpenAI(api_key=self.api_key)
time.sleep(0.1)
if return_text:
completion = completion.choices[0].text
return completion

def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
    """Returns the number of tokens used by a list of messages."""
    if model == "gpt-3.5-turbo":
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
    elif model == "gpt-4":
        return num_tokens_from_messages(messages, model="gpt-4-0314")
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif model == "gpt-4-0314":
        tokens_per_message = 3
        tokens_per_name = 1
    else:
        tokens_per_message, tokens_per_name = 0, 0

    try:
        encoding = tiktoken.get_encoding(model)
    except:
        encoding = tiktoken.get_encoding("cl100k_base")

    num_tokens = 0
    if isinstance(messages, list):
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name
    else:
        num_tokens += len(encoding.encode(messages))
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens

class ClaudeClient:
    def __init__(self, keys):
        from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
        self.anthropic = Anthropic(api_key=keys)

    def chat(self, messages, return_text=True, max_tokens=300, *args, **kwargs):
        system = ' '.join([turn['content'] for turn in messages if turn['role'] == 'system'])
        messages = [turn for turn in messages if turn['role'] != 'system']
        if len(system) == 0:
            system = None
        completion = self.anthropic.beta.messages.create(messages=messages, system=system, max_tokens=max_tokens, *args, **kwargs)
        if return_text:
            completion = completion.content[0].text
        return completion

    def text(self, max_tokens=None, return_text=True, *args, **kwargs):
        completion = self.anthropic.beta.messages.create(max_tokens_to_sample=max_tokens, *args, **kwargs)
        if return_text:
            completion = completion.completion
        return completion

def max_tokens(model):
    if 'gpt-4' in model:
        return 8192
    else:
        return 4096

class LitellmClient:
    def __init__(self, keys=None):
        self.api_key = keys

    def chat(self, return_text=True, *args, **kwargs):
        from litellm import completion
        response = completion(api_key=self.api_key, *args, **kwargs)
        if return_text:
            response = response.choices[0].message.content
        return response

def convert_messages_to_prompt(messages):
    # used for completion model (not chat model)
    prompt = ''
    for turn in messages:
        if turn['role'] == 'system':
            prompt += f"{turn['content']}\n\n"
        elif turn['role'] == 'user':
            prompt += f"{turn['content']}\n\n"
        else:  # 'assistant'
            pass
    prompt += "The ranking results of the 20 passages (only identifiers) is:"
    return prompt


def run_retriever(topics, searcher, qrels=None, k=100, qid=None):
@@ -164,29 +155,32 @@ def create_permutation_instruction(item=None, rank_start=0, rank_end=100, model_
    num = len(item['hits'][rank_start: rank_end])

    max_length = 300
-    while True:
-        messages = get_prefix_prompt(query, num)
-        rank = 0
-        for hit in item['hits'][rank_start: rank_end]:
-            rank += 1
-            content = hit['content']
-            content = content.replace('Title: Content: ', '')
-            content = content.strip()
-            # For Japanese should cut by character: content = content[:int(max_length)]
-            content = ' '.join(content.split()[:int(max_length)])
-            messages.append({'role': 'user', 'content': f"[{rank}] {content}"})
-            messages.append({'role': 'assistant', 'content': f'Received passage [{rank}].'})
-        messages.append({'role': 'user', 'content': get_post_prompt(query, num)})
-
-        if num_tokens_from_messages(messages, model_name) <= max_tokens(model_name) - 200:
-            break
-        else:
-            max_length -= 1
-
+    messages = get_prefix_prompt(query, num)
+    rank = 0
+    for hit in item['hits'][rank_start: rank_end]:
+        rank += 1
+        content = hit['content']
+        content = content.replace('Title: Content: ', '')
+        content = content.strip()
+        # For Japanese should cut by character: content = content[:int(max_length)]
+        content = ' '.join(content.split()[:int(max_length)])
+        messages.append({'role': 'user', 'content': f"[{rank}] {content}"})
+        messages.append({'role': 'assistant', 'content': f'Received passage [{rank}].'})
+    messages.append({'role': 'user', 'content': get_post_prompt(query, num)})

    return messages


def run_llm(messages, api_key=None, model_name="gpt-3.5-turbo"):
-    agent = SafeOpenai(api_key)
+    if 'gpt' in model_name:
+        Client = OpenaiClient
+    # elif 'claude' in model_name:
+    #     Client = ClaudeClient
+    else:
+        Client = LitellmClient

+    agent = Client(api_key)
    response = agent.chat(model=model_name, messages=messages, temperature=0, return_text=True)
    return response

@@ -229,7 +223,7 @@ def receive_permutation(item, permutation, rank_start=0, rank_end=100):

def permutation_pipeline(item=None, rank_start=0, rank_end=100, model_name='gpt-3.5-turbo', api_key=None):
    messages = create_permutation_instruction(item=item, rank_start=rank_start, rank_end=rank_end,
-                                              model_name=model_name)  # chan
+                                              model_name=model_name)  # chan
    permutation = run_llm(messages, api_key=api_key, model_name=model_name)
    item = receive_permutation(item, permutation, rank_start=rank_start, rank_end=rank_end)
    return item
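For readers trying out the refactored pipeline, here is a minimal usage sketch of the new client selection in run_llm. The query, passages, and API key below are placeholders, and get_prefix_prompt/get_post_prompt are assumed to be the prompt helpers defined elsewhere in rank_gpt.py:

# Hypothetical end-to-end call; the key and data below are placeholders.
from rank_gpt import create_permutation_instruction, run_llm, receive_permutation

item = {
    'query': 'how do neural passage rerankers work',
    'hits': [
        {'content': 'Passage about BM25 first-stage retrieval.'},
        {'content': 'Passage about cross-encoder rerankers.'},
    ],
}

# Build the listwise ranking prompt, query the model, apply the permutation.
# run_llm picks OpenaiClient for 'gpt*' model names and LitellmClient otherwise.
messages = create_permutation_instruction(item=item, rank_start=0, rank_end=2,
                                          model_name='gpt-3.5-turbo')
permutation = run_llm(messages, api_key='sk-...', model_name='gpt-3.5-turbo')
item = receive_permutation(item, permutation, rank_start=0, rank_end=2)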

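The commit also drops the loop that repeatedly shrank max_length until the prompt fit the context window; passages are now simply cut to 300 words. If the old guarantee matters for long windows, the helpers shown in this diff can reproduce the check. A sketch, with the 200-token headroom copied from the removed loop:

from rank_gpt import num_tokens_from_messages, max_tokens

messages = [
    {'role': 'system', 'content': 'You are RankGPT, an assistant that ranks passages.'},
    {'role': 'user', 'content': '[1] a candidate passage to rank'},
]

# max_tokens() returns 8192 for gpt-4 models and 4096 for everything else.
budget = max_tokens('gpt-3.5-turbo') - 200
if num_tokens_from_messages(messages, model='gpt-3.5-turbo') > budget:
    print('Prompt too long: shorten passages or use a smaller window.')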

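The new ClaudeClient is included but its branch in run_llm remains commented out, so it has to be constructed directly. A sketch, assuming the anthropic package is installed; the key and model name are placeholders:

from rank_gpt import ClaudeClient

client = ClaudeClient(keys='sk-ant-...')
# chat() pulls system turns out of the message list and passes them via the
# separate 'system' argument that Anthropic's messages API expects.
text = client.chat(model='claude-2.1',
                   messages=[{'role': 'system', 'content': 'You are RankGPT.'},
                             {'role': 'user', 'content': 'Rank the passages ...'}],
                   return_text=True)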