feat: add exact match caching #1717

Merged
merged 26 commits on Dec 15, 2024

Changes from all commits
26 commits
8948a6e
prototype caching
ayulockin Nov 28, 2024
9c885d2
abstract away caching+unique keys
ayulockin Nov 28, 2024
55e7472
remove bad TODO
ayulockin Nov 28, 2024
da974a1
simplify
ayulockin Nov 30, 2024
7bb60a4
Merge branch 'explodinggradients:main' into ayulockin/cache
ayulockin Dec 1, 2024
3656778
add deps in test
ayulockin Dec 3, 2024
f14edac
Merge branch 'explodinggradients:main' into ayulockin/cache
ayulockin Dec 7, 2024
a41b583
caching on/off
ayulockin Dec 7, 2024
a74d73a
formatting
ayulockin Dec 7, 2024
ee56e8a
add tests
ayulockin Dec 7, 2024
b3ebb1f
pyright + windows test
ayulockin Dec 7, 2024
929deb8
Merge branch 'explodinggradients:main' into ayulockin/cache
ayulockin Dec 9, 2024
dc0248b
remove caching to base llm
ayulockin Dec 9, 2024
6c0e651
formatting
ayulockin Dec 9, 2024
6daa30f
Merge branch 'main' into ayulockin/cache
ayulockin Dec 9, 2024
aab887b
Merge branch 'explodinggradients:main' into ayulockin/cache
ayulockin Dec 10, 2024
27f5d40
Merge branch 'explodinggradients:main' into ayulockin/cache
ayulockin Dec 14, 2024
1c12986
make caching parameterized from the llm wrappers
ayulockin Dec 14, 2024
28fa1fb
caching support for embedding
ayulockin Dec 14, 2024
f8e9e61
cache mixin
ayulockin Dec 14, 2024
4884922
formatting
ayulockin Dec 14, 2024
905a072
Revert "cache mixin"
ayulockin Dec 14, 2024
7f5d399
formatting + test
ayulockin Dec 14, 2024
0a79a04
make llm caching implementation less ugly
ayulockin Dec 14, 2024
f84f300
langchain
ayulockin Dec 14, 2024
7299bd3
cleaner embedding
ayulockin Dec 14, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -171,3 +171,4 @@ src/ragas/_version.py
.vscode
.envrc
uv.lock
.cache/
1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ dependencies = [
"pydantic>=2",
"openai>1",
"pysbd>=0.3.4",
"diskcache>=5.6.3",
]
dynamic = ["version", "readme"]

1 change: 1 addition & 0 deletions requirements/test.txt
@@ -3,3 +3,4 @@ pytest-xdist[psutil]
pytest-asyncio
llama_index
nbmake
diskcache
4 changes: 4 additions & 0 deletions src/ragas/__init__.py
@@ -1,3 +1,4 @@
from ragas.cache import CacheInterface, DiskCacheBackend, cacher
from ragas.dataset_schema import EvaluationDataset, MultiTurnSample, SingleTurnSample
from ragas.evaluation import evaluate
from ragas.run_config import RunConfig
@@ -15,4 +16,7 @@
"SingleTurnSample",
"MultiTurnSample",
"EvaluationDataset",
"cacher",
"CacheInterface",
"DiskCacheBackend",
]
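The caching primitives are now re-exported from the package root, so user code can depend on them directly. As a minimal illustration of the interface (the in-memory backend below is a hypothetical example useful for tests, not part of this PR; it only needs the get/set/has_key methods defined in src/ragas/cache.py further down):

from typing import Any

from ragas import CacheInterface


class InMemoryCacheBackend(CacheInterface):
    """Hypothetical dict-backed cache backend, handy for unit tests."""

    def __init__(self) -> None:
        self._store: dict[str, Any] = {}

    def get(self, key: str) -> Any:
        return self._store.get(key)

    def set(self, key: str, value) -> None:
        self._store[key] = value

    def has_key(self, key: str) -> bool:
        return key in self._store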
117 changes: 117 additions & 0 deletions src/ragas/cache.py
@@ -0,0 +1,117 @@
import functools
import hashlib
import inspect
import json
from abc import ABC, abstractmethod
from typing import Any, Optional

from pydantic import BaseModel


class CacheInterface(ABC):
@abstractmethod
def get(self, key: str) -> Any:
pass

@abstractmethod
def set(self, key: str, value) -> None:
pass

@abstractmethod
def has_key(self, key: str) -> bool:
pass


class DiskCacheBackend(CacheInterface):
def __init__(self, cache_dir: str = ".cache"):
try:
from diskcache import Cache
except ImportError:
raise ImportError(
"For using the diskcache backend, please install it with `pip install diskcache`."
)

self.cache = Cache(cache_dir)

def get(self, key: str) -> Any:
return self.cache.get(key)

def set(self, key: str, value) -> None:
self.cache.set(key, value)

def has_key(self, key: str) -> bool:
return key in self.cache

def __del__(self):
if hasattr(self, "cache"):
self.cache.close()


def _make_hashable(o):
if isinstance(o, (tuple, list)):
return tuple(_make_hashable(e) for e in o)
elif isinstance(o, dict):
return tuple(sorted((k, _make_hashable(v)) for k, v in o.items()))
elif isinstance(o, set):
return tuple(sorted(_make_hashable(e) for e in o))
elif isinstance(o, BaseModel):
return _make_hashable(o.model_dump())
else:
return o


EXCLUDE_KEYS = ["callbacks"]


def _generate_cache_key(func, args, kwargs):
if inspect.ismethod(func):
args = args[1:]

filtered_kwargs = {k: v for k, v in kwargs.items() if k not in EXCLUDE_KEYS}

key_data = {
"function": func.__qualname__,
"args": _make_hashable(args),
"kwargs": _make_hashable(filtered_kwargs),
}

key_string = json.dumps(key_data, sort_keys=True, default=str)
cache_key = hashlib.sha256(key_string.encode("utf-8")).hexdigest()
return cache_key


def cacher(cache_backend: Optional[CacheInterface] = None):
def decorator(func):
if cache_backend is None:
return func

# hack to make pyright happy
backend: CacheInterface = cache_backend

is_async = inspect.iscoroutinefunction(func)

@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
cache_key = _generate_cache_key(func, args, kwargs)

if backend.has_key(cache_key):
return backend.get(cache_key)

result = await func(*args, **kwargs)
backend.set(cache_key, result)
return result

@functools.wraps(func)
def sync_wrapper(*args, **kwargs):
cache_key = _generate_cache_key(func, args, kwargs)

if backend.has_key(cache_key):
return backend.get(cache_key)

result = func(*args, **kwargs)
backend.set(cache_key, result)
return result

return async_wrapper if is_async else sync_wrapper

return decorator
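Taken together, a minimal sketch of how cacher and DiskCacheBackend are meant to be combined (expensive_completion is a placeholder function, not something defined in this PR). Keys are SHA-256 hashes of the function's qualified name plus the JSON-serialized positional and keyword arguments (with callbacks excluded), so only exact repeats of a call are served from the cache:

import asyncio

from ragas import DiskCacheBackend, cacher

cache = DiskCacheBackend(cache_dir=".cache")


@cacher(cache_backend=cache)
def expensive_completion(prompt: str) -> str:
    # Placeholder for a slow or costly call; identical prompts after the
    # first call are answered from the on-disk cache.
    return prompt.upper()


@cacher(cache_backend=cache)
async def expensive_completion_async(prompt: str) -> str:
    return prompt.upper()


print(expensive_completion("hello world"))   # computed and cached
print(expensive_completion("hello world"))   # exact-match cache hit
print(asyncio.run(expensive_completion_async("hello world")))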
40 changes: 37 additions & 3 deletions src/ragas/embeddings/base.py
@@ -2,7 +2,7 @@

import asyncio
import typing as t
from abc import ABC
from abc import ABC, abstractmethod
from dataclasses import field
from typing import List

@@ -12,6 +12,7 @@
from pydantic.dataclasses import dataclass
from pydantic_core import CoreSchema, core_schema

from ragas.cache import CacheInterface, cacher
from ragas.run_config import RunConfig, add_async_retry, add_retry

if t.TYPE_CHECKING:
@@ -35,6 +36,20 @@ class BaseRagasEmbeddings(Embeddings, ABC):
"""

run_config: RunConfig
cache: t.Optional[CacheInterface] = None

def __init__(self, cache: t.Optional[CacheInterface] = None):
super().__init__()
self.cache = cache
if self.cache is not None:
self.embed_query = cacher(cache_backend=self.cache)(self.embed_query)
self.embed_documents = cacher(cache_backend=self.cache)(
self.embed_documents
)
self.aembed_query = cacher(cache_backend=self.cache)(self.aembed_query)
self.aembed_documents = cacher(cache_backend=self.cache)(
self.aembed_documents
)

async def embed_text(self, text: str, is_async=True) -> List[float]:
"""
@@ -61,6 +76,12 @@ async def embed_texts(
)
return await loop.run_in_executor(None, embed_documents_with_retry, texts)

@abstractmethod
async def aembed_query(self, text: str) -> List[float]: ...

@abstractmethod
async def aembed_documents(self, texts: List[str]) -> t.List[t.List[float]]: ...

def set_run_config(self, run_config: RunConfig):
"""
Set the run configuration for the embedding operations.
Expand All @@ -85,8 +106,12 @@ class LangchainEmbeddingsWrapper(BaseRagasEmbeddings):
"""

def __init__(
self, embeddings: Embeddings, run_config: t.Optional[RunConfig] = None
self,
embeddings: Embeddings,
run_config: t.Optional[RunConfig] = None,
cache: t.Optional[CacheInterface] = None,
):
super().__init__(cache=cache)
self.embeddings = embeddings
if run_config is None:
run_config = RunConfig()
@@ -189,11 +214,13 @@ class HuggingfaceEmbeddings(BaseRagasEmbeddings):
cache_folder: t.Optional[str] = None
model_kwargs: t.Dict[str, t.Any] = field(default_factory=dict)
encode_kwargs: t.Dict[str, t.Any] = field(default_factory=dict)
cache: t.Optional[CacheInterface] = None

def __post_init__(self):
"""
Initialize the model after the object is created.
"""
super().__init__(cache=self.cache)
try:
import sentence_transformers
from transformers import AutoConfig
@@ -226,6 +253,9 @@ def __post_init__(self):
if "convert_to_tensor" not in self.encode_kwargs:
self.encode_kwargs["convert_to_tensor"] = True

if self.cache is not None:
self.predict = cacher(cache_backend=self.cache)(self.predict)

def embed_query(self, text: str) -> List[float]:
"""
Embed a single query text.
@@ -297,8 +327,12 @@ class LlamaIndexEmbeddingsWrapper(BaseRagasEmbeddings):
"""

def __init__(
self, embeddings: BaseEmbedding, run_config: t.Optional[RunConfig] = None
self,
embeddings: BaseEmbedding,
run_config: t.Optional[RunConfig] = None,
cache: t.Optional[CacheInterface] = None,
):
super().__init__(cache=cache)
self.embeddings = embeddings
if run_config is None:
run_config = RunConfig()
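With these changes every embedding wrapper accepts an optional cache argument; when it is set, embed_query, embed_documents and their async counterparts are wrapped by cacher at construction time. A sketch under the assumption that langchain-openai is installed (OpenAIEmbeddings is only an illustrative model choice, not something this PR requires):

from langchain_openai import OpenAIEmbeddings

from ragas import DiskCacheBackend
from ragas.embeddings import LangchainEmbeddingsWrapper

cache = DiskCacheBackend(cache_dir=".cache")
embedder = LangchainEmbeddingsWrapper(OpenAIEmbeddings(), cache=cache)

# The first call goes to the provider; repeating the identical call is an
# exact-match cache hit and returns the stored vector from disk.
vector = embedder.embed_query("What is the capital of France?")
vector_cached = embedder.embed_query("What is the capital of France?")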
12 changes: 12 additions & 0 deletions src/ragas/llms/base.py
@@ -13,6 +13,7 @@
from langchain_openai.llms import AzureOpenAI, OpenAI
from langchain_openai.llms.base import BaseOpenAI

from ragas.cache import CacheInterface, cacher
from ragas.exceptions import LLMDidNotFinishException
from ragas.integrations.helicone import helicone_config
from ragas.run_config import RunConfig, add_async_retry
@@ -47,6 +48,13 @@ def is_multiple_completion_supported(llm: BaseLanguageModel) -> bool:
class BaseRagasLLM(ABC):
run_config: RunConfig = field(default_factory=RunConfig, repr=False)
multiple_completion_supported: bool = field(default=False, repr=False)
cache: t.Optional[CacheInterface] = field(default=None, repr=False)

def __post_init__(self):
# If a cache_backend is provided, wrap the implementation methods at construction time.
if self.cache is not None:
self.generate_text = cacher(cache_backend=self.cache)(self.generate_text)
self.agenerate_text = cacher(cache_backend=self.cache)(self.agenerate_text)

def set_run_config(self, run_config: RunConfig):
self.run_config = run_config
@@ -124,7 +132,9 @@ def __init__(
langchain_llm: BaseLanguageModel,
run_config: t.Optional[RunConfig] = None,
is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
cache: t.Optional[CacheInterface] = None,
):
super().__init__(cache=cache)
self.langchain_llm = langchain_llm
if run_config is None:
run_config = RunConfig()
@@ -273,7 +283,9 @@ def __init__(
self,
llm: BaseLLM,
run_config: t.Optional[RunConfig] = None,
cache: t.Optional[CacheInterface] = None,
):
super().__init__(cache=cache)
self.llm = llm

try:
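The LLM wrappers expose the same optional cache parameter; BaseRagasLLM.__post_init__ wraps generate_text and agenerate_text when a backend is supplied, and the wrapper constructors pass it through. A sketch, again assuming langchain-openai is available (the ChatOpenAI model name is illustrative):

from langchain_openai import ChatOpenAI

from ragas import DiskCacheBackend
from ragas.llms import LangchainLLMWrapper

cache = DiskCacheBackend(cache_dir=".cache")
llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini"), cache=cache)

# Use the wrapped LLM with evaluate()/metrics as usual; repeated identical
# generate_text / agenerate_text calls are answered from the cache.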
1 change: 1 addition & 0 deletions src/ragas/prompt/pydantic_prompt.py
@@ -175,6 +175,7 @@ async def generate_multiple(
If there's an error parsing the output.
"""
callbacks = callbacks or []

processed_data = self.process_input(data)
prompt_rm, prompt_cb = new_group(
name=self.name,