-
Notifications
You must be signed in to change notification settings - Fork 235
/
Copy pathcohere.py
78 lines (68 loc) · 2.72 KB
/
cohere.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
from typing import Any, List, Optional
from pydantic import PrivateAttr
from semantic_router.encoders import DenseEncoder
from semantic_router.utils.defaults import EncoderDefault
class CohereEncoder(DenseEncoder):
    """Dense encoder backed by the Cohere embeddings API.

    Each call sends the input texts to ``cohere.Client.embed`` using this
    encoder's model ``name`` and ``input_type``, and returns the response's
    ``embeddings`` field.
    """

    # Cohere SDK client; created in ``__init__`` via ``_initialize_client``.
    _client: Any = PrivateAttr()
    # Response class (``EmbeddingsByTypeEmbedResponse``) that this encoder does
    # not handle; resolved at client-initialization time so that importing
    # this module does not require the optional ``cohere`` package.
    _embed_type: Any = PrivateAttr()
    type: str = "cohere"
    # Sent as ``input_type`` with every embed request.
    input_type: Optional[str] = "search_query"

    def __init__(
        self,
        name: Optional[str] = None,
        cohere_api_key: Optional[str] = None,
        score_threshold: float = 0.3,
        input_type: Optional[str] = "search_query",
    ):
        """Create the encoder and its Cohere client.

        :param name: Embedding model name. Defaults to
            ``EncoderDefault.COHERE.value["embedding_model"]``.
        :param cohere_api_key: API key for Cohere. Falls back to the
            ``COHERE_API_KEY`` environment variable when not given.
        :param score_threshold: Threshold forwarded to the base encoder.
        :param input_type: Value sent as ``input_type`` on every embed call.
        :raises ImportError: If the ``cohere`` package is not installed.
        :raises ValueError: If no API key is available or the client fails
            to initialize.
        """
        if name is None:
            name = EncoderDefault.COHERE.value["embedding_model"]
        super().__init__(
            name=name,
            score_threshold=score_threshold,
            input_type=input_type,  # type: ignore
        )
        self.input_type = input_type
        self._client = self._initialize_client(cohere_api_key)

    def _initialize_client(self, cohere_api_key: Optional[str] = None):
        """Initializes the Cohere client.

        Also records the ``EmbeddingsByTypeEmbedResponse`` class on
        ``self._embed_type`` so ``__call__`` can detect that response shape.

        :param cohere_api_key: The API key for the Cohere client, can also
            be set via the COHERE_API_KEY environment variable.
        :return: An instance of the Cohere client.
        :raises ImportError: If the ``cohere`` package cannot be imported.
        :raises ValueError: If no (non-empty) API key is available, or if
            constructing the client raises.
        """
        try:
            import cohere
            from cohere.types.embed_response import EmbeddingsByTypeEmbedResponse
        except ImportError as e:
            raise ImportError(
                "Please install Cohere to use CohereEncoder. "
                "You can install it with: "
                "`pip install 'semantic-router[cohere]'`"
            ) from e
        self._embed_type = EmbeddingsByTypeEmbedResponse

        cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
        # Reject an empty string as well as None (e.g. COHERE_API_KEY="").
        if not cohere_api_key:
            raise ValueError("Cohere API key cannot be 'None'.")
        try:
            client = cohere.Client(cohere_api_key)
        except Exception as e:
            raise ValueError(
                f"Cohere API client failed to initialize. Error: {e}"
            ) from e
        return client

    def __call__(self, docs: List[str]) -> List[List[float]]:
        """Embed ``docs`` with the configured Cohere model.

        :param docs: Texts to embed. An empty list returns ``[]`` without
            contacting the API.
        :return: The ``embeddings`` field of the Cohere embed response.
        :raises ValueError: If the client is not initialized or the Cohere
            API call itself fails.
        :raises NotImplementedError: If Cohere returns an
            ``EmbeddingsByTypeEmbedResponse``, which is not handled here.
        """
        # getattr: an unset pydantic private attribute raises AttributeError
        # on access; treat that the same as an explicit None.
        if getattr(self, "_client", None) is None:
            raise ValueError("Cohere client is not initialized.")
        if not docs:
            return []
        try:
            embeds = self._client.embed(
                texts=docs, input_type=self.input_type, model=self.name
            )
        except Exception as e:
            raise ValueError(f"Cohere API call failed. Error: {e}") from e
        # Checked outside the try block so this error is not swallowed and
        # re-wrapped as a generic "API call failed" ValueError.
        if isinstance(embeds, self._embed_type):
            raise NotImplementedError(
                "Handling of EmbeddingsByTypeEmbedResponse is not implemented."
            )
        return embeds.embeddings