1 parent b760d66 · commit c37c6d4
Showing 3 changed files with 169 additions and 0 deletions.
@@ -0,0 +1,151 @@
from typing import (
    Any,
    Optional,
    Sequence,
)
import os

import google.generativeai as genai

from llama_index.bridge.pydantic import Field
from llama_index.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.llm import LLM
from llama_index.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)

DEFAULT_MODEL = "gemini-pro"


class Gemini(LLM):
    """LLM wrapper for Google Gemini that calls the Generative Language API
    over REST, so a custom api_base (e.g. a relay/proxy endpoint) can be used."""

    model_name: str = Field(
        default=DEFAULT_MODEL, description="The Gemini model to use."
    )
    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    temperature: float = Field(
        default=0.1,
        description="The temperature to use for sampling.",
        gt=0,
    )
    api_base: str = Field(
        default="generativelanguage.googleapis.com",
        description="The base URL for the Gemini API.",
    )

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        temperature: float = 0.1,
        max_tokens: Optional[int] = None,
        api_base: str = "generativelanguage.googleapis.com",
        **kwargs: Any,
    ) -> None:
        # The Generative Language API expects fully qualified model names
        # such as "models/gemini-pro".
        if "models/" not in model_name:
            model_name = f"models/{model_name}"

        # Allow the endpoint to be overridden via the environment,
        # e.g. to route requests through a relay.
        env_api_base = os.getenv("GOOGLE_API_BASE")
        if env_api_base is not None:
            api_base = env_api_base

        super().__init__(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            api_base=api_base,
            **kwargs,
        )

    def call_with_prompt(self, prompt: str) -> Optional[str]:
        # export GOOGLE_API_KEY="YOUR_KEY"
        # export GOOGLE_API_BASE="generativelanguage.googleapis.com"
        genai.configure(
            api_key=os.getenv("GOOGLE_API_KEY"),
            client_options={"api_endpoint": self.api_base},
            transport="rest",
        )
        model = genai.GenerativeModel(
            self.model_name,
            generation_config=genai.GenerationConfig(
                temperature=self.temperature,
                max_output_tokens=self.max_tokens,
            ),
        )
        response = model.generate_content(prompt)

        # If the API keeps failing, you can copy the prompt and ask it
        # directly in the web UI instead.
        # print(prompt)

        # Return the generated text when a response was received; otherwise
        # return None so the caller can handle the failure.
        if response is not None:
            return response.text
        return None

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        answer = self.call_with_prompt(prompt)

        return CompletionResponse(
            text=answer,
        )

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=6000,
            num_output=self.max_tokens or -1,
            # is_chat_model=is_chat_model(model=self._get_model_name()),
            is_chat_model=False,
            is_function_calling_model=False,
            # is_function_calling_model=is_function_calling_model(
            #     model=self._get_model_name()
            # ),
            model_name=self.model_name,
        )

    # The methods below are required by the LLM interface but are not used
    # here, so they are all left as stubs.
    @llm_completion_callback()
    async def astream_complete(self) -> CompletionResponseAsyncGen:
        pass

    async def _astream_chat(self) -> ChatResponseAsyncGen:
        pass

    @llm_chat_callback()
    async def astream_chat(self) -> ChatResponseAsyncGen:
        pass

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        pass

    @llm_chat_callback()
    def stream_chat(self) -> ChatResponseGen:
        pass

    @llm_completion_callback()
    def stream_complete(self) -> CompletionResponseGen:
        pass

    @llm_chat_callback()
    async def achat(self) -> ChatResponse:
        pass

    @llm_completion_callback()
    async def acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        pass
@@ -0,0 +1,13 @@
# Using Google Gemini Pro/Ultra as the LLM

(Support for Google Gemini Pro/Ultra is contributed by [BetterAndBetterII](https://github.com/betterandbetterii).)

Usage:
1. Apply for a Google Gemini Pro API key: [apply here](https://makersuite.google.com/app/apikey) (a Google account is all that is needed; the key is free to obtain and use).
2. Set the environment variable GOOGLE_API_KEY="YOUR_KEY" (optionally set GOOGLE_API_BASE to the address of a relayed API endpoint); see the sketch after this list.
3. In the configuration file config.yaml, set `name` to `gemini-pro` or `gemini-ultra`.
4. The remaining steps are the same as in the original tutorial (the Zilliz Pipeline approach works similarly).
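
For step 2, a minimal sketch of setting these variables from Python before the model is constructed (the values are placeholders; in a shell you would `export` them instead):

```python
import os

# Placeholder values: replace with your own key and, if needed, relay endpoint.
os.environ["GOOGLE_API_KEY"] = "YOUR_KEY"
# Optional: only needed when routing through a relay instead of the default
# generativelanguage.googleapis.com endpoint.
os.environ["GOOGLE_API_BASE"] = "your-relay.example.com"
```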

(GeminiLLM.py was written with QwenLLM.py as a reference.)

Because the built-in Gemini class in llama_index cannot be configured with transport='rest', it cannot talk to a relayed API address, so GeminiLLM.py implements the llama_index LLM interface by hand and allows the transport to be changed.
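
As a rough usage sketch (module name taken from the GeminiLLM.py file in this commit; only `complete` is implemented, the chat/streaming methods are stubs):

```python
# Assumes GOOGLE_API_KEY (and optionally GOOGLE_API_BASE) are already set
# as described in the steps above.
from GeminiLLM import Gemini

llm = Gemini(model_name="gemini-pro", temperature=0.1, max_tokens=1024)
response = llm.complete("Say hello in one short sentence.")
print(response.text)
```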