
dev -> main #514

Merged
merged 13 commits into main from dev on Sep 22, 2024
feat: Add Vertex AI embedder
whiterabbit1983 committed Sep 19, 2024
commit 283148633dd77e12b94deb10637c1cdd2411a5c7
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/embed_docs.py
@@ -2,7 +2,7 @@
 from temporalio import activity
 
 from ..clients import cozo
-from ..clients import embed as embedder
+from ..clients import vertexai as embedder
 from ..env import testing
 from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
 from .types import EmbedDocsPayload
18 changes: 18 additions & 0 deletions agents-api/agents_api/clients/vertexai.py
@@ -0,0 +1,18 @@
import litellm
from litellm import aembedding

from ..env import google_project_id, vertex_location

litellm.vertex_project = google_project_id
litellm.vertex_location = vertex_location


async def embed(
    inputs: list[str], dimensions: int = 1024, join_inputs: bool = True
) -> list[list[float]]:
    input = ["\n\n".join(inputs)] if join_inputs else inputs
    response = await aembedding(
        model="vertex_ai/text-embedding-004", input=input, dimensions=dimensions
    )

    return response.data or []
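
For context, the new client exposes a single async helper. A minimal usage sketch (not part of the diff; the sample strings and the asyncio driver are illustrative, and it assumes GOOGLE_PROJECT_ID plus Vertex AI credentials are already configured in the environment):

import asyncio

from agents_api.clients import vertexai


async def main() -> None:
    # join_inputs=True joins the snippets with blank lines, so a single
    # embedding request is made for the combined text.
    vectors = await vertexai.embed(
        ["Julep is a platform for building AI agents.", "Documents are stored as snippets."],
        dimensions=1024,
        join_inputs=True,
    )
    print(f"received {len(vectors)} embedding(s)")


asyncio.run(main())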
4 changes: 4 additions & 0 deletions agents-api/agents_api/env.py
@@ -77,6 +77,9 @@
 temporal_endpoint: Any = env.str("TEMPORAL_ENDPOINT", default="localhost:7233")
 temporal_task_queue: Any = env.str("TEMPORAL_TASK_QUEUE", default="julep-task-queue")
 
+# Google cloud
+google_project_id: str = env.str("GOOGLE_PROJECT_ID")
+vertex_location: str = env.str("VERTEX_LOCATION", default="us-central1")
 
 # Consolidate environment variables
 environment: Dict[str, Any] = dict(
@@ -97,6 +100,7 @@
     temporal_namespace=temporal_namespace,
     embedding_model_id=embedding_model_id,
     testing=testing,
+    google_project_id=google_project_id,
 )
 
 if debug or testing:
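Because GOOGLE_PROJECT_ID is read with no default, the service now requires it at startup, while VERTEX_LOCATION falls back to us-central1. A standalone sketch of the same pattern (illustrative only, outside the agents_api package, assuming the module's env object is an environs.Env as the env.str calls suggest):

from environs import Env

env = Env()
env.read_env()  # pick up a local .env file if present

# Raises if GOOGLE_PROJECT_ID is unset; VERTEX_LOCATION has a default.
google_project_id: str = env.str("GOOGLE_PROJECT_ID")
vertex_location: str = env.str("VERTEX_LOCATION", default="us-central1")

print(google_project_id, vertex_location)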
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/chat/gather_messages.py
@@ -9,7 +9,7 @@
 from agents_api.autogen.Chat import ChatInput
 
 from ...autogen.openapi_model import DocReference, History
-from ...clients import embed
+from ...clients import vertexai as embed
 from ...common.protocol.developers import Developer
 from ...common.protocol.sessions import ChatContext
 from ..docs.search_docs_hybrid import search_docs_hybrid
@@ -61,7 +61,7 @@ async def gather_messages(
         return past_messages, []
 
     # Search matching docs
-    [query_embedding, *_] = await embed.embed(
+    query_embedding = await embed.embed(
         inputs=[
             f"{msg.get('name') or msg['role']}: {msg['content']}"
             for msg in new_raw_messages
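Worth noting when reading this hunk: the old form unpacked the first vector from the embedder's return value, whereas the new form binds the whole return value. A toy illustration of the difference (values are made up):

embeddings = [[0.1, 0.2, 0.3]]  # an embed() call returning one vector

[query_embedding, *_] = embeddings   # old: query_embedding == [0.1, 0.2, 0.3]
query_embedding = embeddings         # new: query_embedding == [[0.1, 0.2, 0.3]]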
4 changes: 2 additions & 2 deletions agents-api/agents_api/routers/docs/embed.py
@@ -1,9 +1,9 @@
 from typing import Annotated
+from uuid import UUID
 
 from fastapi import Depends
-from uuid import UUID
 
-import agents_api.clients.embed as embedder
+import agents_api.clients.vertexai as embedder
 
 from ...autogen.openapi_model import (
     EmbedQueryRequest,
3 changes: 2 additions & 1 deletion agents-api/tests/test_chat_routes.py
@@ -3,7 +3,8 @@
 from ward import test
 
 from agents_api.autogen.openapi_model import ChatInput, CreateSessionRequest
-from agents_api.clients import embed, litellm
+from agents_api.clients import litellm
+from agents_api.clients import vertexai as embed
 from agents_api.common.protocol.sessions import ChatContext
 from agents_api.models.chat.gather_messages import gather_messages
 from agents_api.models.chat.prepare_chat_context import prepare_chat_context
6 changes: 6 additions & 0 deletions llm-proxy/litellm-config.yaml
@@ -103,6 +103,12 @@ model_list:
     api_base: os.environ/EMBEDDING_SERVICE_BASE
   tags: ["free"]
 
+- model_name: text-embedding-004
+  litellm_params:
+    model: vertex_ai/text-embedding-004
+    vertex_project: os.environ/GOOGLE_PROJECT_ID
+    vertex_location: os.environ/VERTEX_LOCATION
+
 
 # -*= Free models =*-
 # -------------------
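Registering text-embedding-004 in the proxy config means the model can also be reached through litellm's OpenAI-compatible endpoint. A hedged sketch of such a call (the base URL, port, and API key are placeholders, not values from this repository):

from openai import OpenAI

# Point the OpenAI client at the litellm proxy; adjust the URL and key
# to match the actual deployment.
client = OpenAI(base_url="http://localhost:4000/v1", api_key="placeholder-key")

response = client.embeddings.create(
    model="text-embedding-004",
    input=["What is Julep?"],
)
print(len(response.data[0].embedding))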