feat: add support for new cohere models in cohere and bedrock embedding functions (lancedb#1335)

Fixes lancedb#1329

Will update docs on lancedb#1326
AyushExel authored May 30, 2024
1 parent 1b2463c commit 16eff25
Showing 3 changed files with 80 additions and 18 deletions.
python/python/lancedb/embeddings/base.py: 2 changes (1 addition, 1 deletion)
@@ -153,7 +153,7 @@ def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.arr

@abstractmethod
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
self, texts: Union[List[str], np.ndarray], *args, **kwargs
) -> List[np.array]:
"""
Generate the embeddings for the given texts
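
The base-class change above only widens the abstract generate_embeddings signature to accept *args/**kwargs, so concrete embedding functions can forward provider-specific options (such as Cohere's input_type) without changing the interface. Below is a minimal sketch of a subclass written against the new signature; the provider, model name, dimension, and placeholder vectors are hypothetical and not part of this commit.

from typing import List, Union

import numpy as np

from lancedb.embeddings.base import TextEmbeddingFunction


class FakeEmbeddingFunction(TextEmbeddingFunction):
    """Hypothetical embedding function, shown only to illustrate the new signature."""

    name: str = "fake-model-v1"  # hypothetical model name

    def ndims(self):
        return 256  # hypothetical fixed dimension

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray], *args, **kwargs
    ) -> List[np.array]:
        # Provider-specific options such as input_type now arrive via **kwargs
        # and would be forwarded to the provider's client here, e.g.
        #   client.embed(texts=texts, input_type=kwargs.get("input_type")).
        return [np.random.rand(self.ndims()) for _ in texts]  # placeholder vectors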
python/python/lancedb/embeddings/bedrock.py: 28 changes (18 additions, 10 deletions)
@@ -73,6 +73,8 @@ class TextModel(LanceModel):
assumed_role: Union[str, None] = None
profile_name: Union[str, None] = None
role_session_name: str = "lancedb-embeddings"
source_input_type: str = "search_document"
query_input_type: str = "search_query"

if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat

@@ -87,21 +89,29 @@ def ndims(self):
# TODO: fix hardcoding
if self.name == "amazon.titan-embed-text-v1":
return 1536
elif self.name in {"cohere.embed-english-v3", "cohere.embed-multilingual-v3"}:
elif self.name in [
"amazon.titan-embed-text-v2:0",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3",
]:
# TODO: "amazon.titan-embed-text-v2:0" model supports dynamic ndims
return 1024
else:
raise ValueError(f"Unknown model name: {self.name}")
raise ValueError(f"Model {self.name} not supported")

def compute_query_embeddings(
self, query: str, *args, **kwargs
) -> List[List[float]]:
return self.compute_source_embeddings(query)
return self.compute_source_embeddings(query, input_type=self.query_input_type)

def compute_source_embeddings(
self, texts: TEXT, *args, **kwargs
) -> List[List[float]]:
texts = self.sanitize_input(texts)
return self.generate_embeddings(texts)
# assume source input type if not passed by `compute_query_embeddings`
kwargs["input_type"] = kwargs.get("input_type") or self.source_input_type

return self.generate_embeddings(texts, **kwargs)

def generate_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs
@@ -121,11 +131,11 @@ def generate_embeddings(
"""
results = []
for text in texts:
response = self._generate_embedding(text)
response = self._generate_embedding(text, *args, **kwargs)
results.append(response)
return results

def _generate_embedding(self, text: str) -> List[float]:
def _generate_embedding(self, text: str, *args, **kwargs) -> List[float]:
"""
Get the embeddings for the given texts
@@ -141,14 +151,12 @@ def _generate_embedding(self, text: str) -> List[float]:
"""
# format input body for provider
provider = self.name.split(".")[0]
_model_kwargs = {}
input_body = {**_model_kwargs}
input_body = {**kwargs}
if provider == "cohere":
if "input_type" not in input_body.keys():
input_body["input_type"] = "search_document"
input_body["texts"] = [text]
else:
# includes common provider == "amazon"
input_body.pop("input_type", None)
input_body["inputText"] = text
body = json.dumps(input_body)

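
With the bedrock.py changes above, Cohere models on Bedrock embed table rows with source_input_type and queries with query_input_type, and both entry points forward that choice through **kwargs down to _generate_embedding. A usage sketch follows; it assumes AWS credentials are configured and that this class is registered under the "bedrock-text" key (the registry key itself is not visible in this diff).

from lancedb.embeddings import get_registry

bedrock = get_registry().get("bedrock-text").create(
    name="cohere.embed-english-v3",
    source_input_type="search_document",  # applied to rows being ingested
    query_input_type="search_query",      # applied to search queries
)

# Rows go through compute_source_embeddings -> input_type="search_document" ...
doc_vectors = bedrock.compute_source_embeddings(["hello world", "salut monde"])
# ... while queries go through compute_query_embeddings -> input_type="search_query".
query_vector = bedrock.compute_query_embeddings("hello")[0]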
python/python/lancedb/embeddings/cohere.py: 68 changes (61 additions, 7 deletions)
@@ -19,7 +19,7 @@
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help
from .utils import api_key_not_found_help, TEXT


@register("cohere")
@@ -32,8 +32,36 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
Parameters
----------
name: str, default "embed-multilingual-v2.0"
The name of the model to use. See the Cohere documentation for
a list of available models.
The name of the model to use. List of acceptable models:
* embed-english-v3.0
* embed-multilingual-v3.0
* embed-english-light-v3.0
* embed-multilingual-light-v3.0
* embed-english-v2.0
* embed-english-light-v2.0
* embed-multilingual-v2.0
source_input_type: str, default "search_document"
The input type for the source column in the database
query_input_type: str, default "search_query"
The input type for the query column in the database
Cohere supports the following input types:
| Input Type | Description |
|-------------------------|---------------------------------------|
| "`search_document`" | Used for embeddings stored in a vector|
| | database for search use-cases. |
| "`search_query`" | Used for embeddings of search queries |
| | run against a vector DB |
| "`semantic_similarity`" | Specifies the given text will be used |
| | for Semantic Textual Similarity (STS) |
| "`classification`" | Used for embeddings passed through a |
| | text classifier. |
| "`clustering`" | Used for the embeddings run through a |
| | clustering algorithm |
Examples
--------
@@ -61,14 +89,39 @@ class TextModel(LanceModel):
"""

name: str = "embed-multilingual-v2.0"
source_input_type: str = "search_document"
query_input_type: str = "search_query"
client: ClassVar = None

def ndims(self):
# TODO: fix hardcoding
return 768
if self.name in [
"embed-english-v3.0",
"embed-multilingual-v3.0",
"embed-english-light-v2.0",
]:
return 1024
elif self.name in ["embed-english-light-v3.0", "embed-multilingual-light-v3.0"]:
return 384
elif self.name == "embed-english-v2.0":
return 4096
elif self.name == "embed-multilingual-v2.0":
return 768
else:
raise ValueError(f"Model {self.name} not supported")

def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
return self.compute_source_embeddings(query, input_type=self.query_input_type)

def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
input_type = (
kwargs.get("input_type") or self.source_input_type
) # assume source input type if not passed by `compute_query_embeddings`
return self.generate_embeddings(texts, input_type=input_type)

def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
self, texts: Union[List[str], np.ndarray], *args, **kwargs
) -> List[np.array]:
"""
Get the embeddings for the given texts
@@ -78,9 +131,10 @@ def generate_embeddings(
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
# TODO retry, rate limit, token limit
self._init_client()
rs = CohereEmbeddingFunction.client.embed(texts=texts, model=self.name)
rs = CohereEmbeddingFunction.client.embed(
texts=texts, model=self.name, **kwargs
)

return [emb for emb in rs.embeddings]

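
A usage sketch of the extended Cohere function with one of the newly supported v3.0 models; it assumes COHERE_API_KEY is set, uses the "cohere" registry key declared in this file, and follows the standard LanceDB SourceField/VectorField schema pattern.

from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

cohere = get_registry().get("cohere").create(
    name="embed-english-v3.0",            # 1024-dimensional per ndims() above
    source_input_type="search_document",
    query_input_type="search_query",
)


class Document(LanceModel):
    text: str = cohere.SourceField()
    vector: Vector(cohere.ndims()) = cohere.VectorField()

Rows added to a table with this schema are embedded as "search_document", while strings passed to table.search(...) are embedded as "search_query".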
