
Commit

Add support for using a proxy API base URL with the Google Gemini models
BetterAndBetterII committed Feb 5, 2024
1 parent b760d66 commit c37c6d4
Showing 3 changed files with 169 additions and 0 deletions.
151 changes: 151 additions & 0 deletions custom/llms/GeminiLLM.py
@@ -0,0 +1,151 @@
from typing import (
    Any,
    Optional,
    Sequence,
)

from llama_index.bridge.pydantic import Field
from llama_index.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.llm import LLM
from llama_index.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)

import google.generativeai as genai
import os

DEFAULT_MODEL = "gemini-pro"

class Gemini(LLM):
    model_name: str = Field(
        default=DEFAULT_MODEL, description="The Gemini model to use."
    )

    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate.",
        gt=0,
    )

    temperature: float = Field(
        default=0.1,
        description="The temperature to use for sampling.",
        gt=0,
    )

    api_base: str = Field(
        default="generativelanguage.googleapis.com",
        description="The base URL for the Gemini API.",
    )

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        temperature: float = 0.1,
        max_tokens: Optional[int] = None,
        api_base: str = "generativelanguage.googleapis.com",
        **kwargs: Any,
    ) -> None:
        # The Gemini API expects fully qualified names such as "models/gemini-pro".
        if model_name.find("models/") == -1:
            model_name = f"models/{model_name}"

        # The GOOGLE_API_BASE environment variable, if set, overrides api_base,
        # which is what allows requests to be routed through a proxy endpoint.
        if os.getenv("GOOGLE_API_BASE") is not None:
            api_base = os.getenv("GOOGLE_API_BASE")

        super().__init__(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            api_base=api_base,
            **kwargs,
        )


    def call_with_prompt(self, prompt):
        # export GOOGLE_API_KEY="YOUR_KEY"
        # export GOOGLE_API_BASE="generativelanguage.googleapis.com"
        genai.configure(
            api_key=os.getenv("GOOGLE_API_KEY"),
            client_options={"api_endpoint": self.api_base},
            transport="rest",
        )
        model = genai.GenerativeModel(
            self.model_name,
            generation_config=genai.GenerationConfig(
                temperature=self.temperature,
                max_output_tokens=self.max_tokens,
            ),
        )
        response = model.generate_content(prompt)

        # If the API keeps failing, you can copy the prompt
        # into the web UI and ask there directly.
        # print(prompt)

        if response is not None:
            return response.text

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        answer = self.call_with_prompt(prompt)

        return CompletionResponse(
            text=answer,
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=6000,
            num_output=self.max_tokens or -1,
            is_chat_model=False,
            is_function_calling_model=False,
            model_name=self.model_name,
        )

    # The methods below are required by the LLM interface,
    # but they are not used here, so they are all no-ops.
    @llm_completion_callback()
    async def astream_complete(self) -> CompletionResponseAsyncGen:
        pass

    async def _astream_chat(self) -> ChatResponseAsyncGen:
        pass

    @llm_chat_callback()
    async def astream_chat(self) -> ChatResponseAsyncGen:
        pass

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        pass

    @llm_chat_callback()
    def stream_chat(self) -> ChatResponseGen:
        pass

    @llm_completion_callback()
    def stream_complete(self) -> CompletionResponseGen:
        pass

    @llm_chat_callback()
    async def achat(self) -> ChatResponse:
        pass

    @llm_completion_callback()
    async def acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        pass
13 changes: 13 additions & 0 deletions docs/Gemini.md
@@ -0,0 +1,13 @@
# Using Google Gemini Pro/Ultra as the LLM

(Support for Google Gemini Pro/Ultra is contributed by [BetterAndBetterII](https://github.com/betterandbetterii).)

Usage:
1. Apply for a Google Gemini Pro API key at [makersuite](https://makersuite.google.com/app/apikey) (a Google account is all you need; applying for and using the key is free).
2. Set the environment variable GOOGLE_API_KEY="YOUR_KEY" (optionally set GOOGLE_API_BASE="your proxy API endpoint").
3. In the config file config.yaml, set `name` to `gemini-pro` or `gemini-ultra`.
4. The remaining steps are the same as in the original tutorial. (The ZillizPipeline setup is similar.)

(GeminiLLM.py is modeled on QwenLLM.py.)

Because the Gemini class in llama index cannot be configured with transport='rest', it cannot use a proxy API endpoint, so GeminiLLM.py implements the llama index LLM interface by hand and makes the transport configurable.
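
A minimal usage sketch (the import path follows the repository layout above; the proxy hostname is a placeholder):

```python
import os

from custom.llms.GeminiLLM import Gemini

# Equivalent to: export GOOGLE_API_KEY="YOUR_KEY"
os.environ["GOOGLE_API_KEY"] = "YOUR_KEY"
# Optional: route requests through a proxy endpoint (placeholder hostname).
# This must be set before constructing Gemini, since __init__ reads it.
os.environ["GOOGLE_API_BASE"] = "your-proxy.example.com"

# GOOGLE_API_BASE overrides the default api_base, so every REST call goes
# to the proxy instead of generativelanguage.googleapis.com.
llm = Gemini(model_name="gemini-pro", temperature=0.1, max_tokens=2048)
print(llm.complete("Hello, Gemini!").text)
```

When running through executor.py instead, it is enough to set `name` to `gemini-pro` under the `llm` section of config.yaml; the executor constructs the same `Gemini` instance itself.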
5 changes: 5 additions & 0 deletions executor.py
@@ -28,6 +28,7 @@

from custom.history_sentence_window import HistorySentenceWindowNodeParser
from custom.llms.QwenLLM import QwenUnofficial
from custom.llms.GeminiLLM import Gemini

from pymilvus import MilvusClient

@@ -110,6 +111,8 @@ def __init__(self, config):
        # Use the Qwen (Tongyi Qianwen) model
        if config.llm.name == "qwen":
            llm = QwenUnofficial(temperature=config.llm.temperature, model=config.llm.name, max_tokens=2048)
        elif config.llm.name.find("gemini") != -1:
            llm = Gemini(temperature=config.llm.temperature, model_name=config.llm.name, max_tokens=2048)
        else:
            api_base = None
            if 'api_base' in config.llm:
@@ -235,6 +238,8 @@ def __init__(self, config):

        if config.llm.name == "qwen":
            llm = QwenUnofficial(temperature=config.llm.temperature, model=config.llm.name, max_tokens=2048)
        elif config.llm.name.find("gemini") != -1:
            llm = Gemini(model_name=config.llm.name, temperature=config.llm.temperature, max_tokens=2048)
        else:
            api_base = None
            if 'api_base' in config.llm:
