Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
bclavie committed Dec 29, 2023
0 parents commit 4925fd4
Show file tree
Hide file tree
Showing 18 changed files with 4,394 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/FUNDING.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
github: bclavie
121 changes: 121 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so


# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py
.pdm.toml
__pypackages__/

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
.envrc

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/


.mypy.ipynb_checkpoints
.mkdocs.yml


archive/
52 changes: 52 additions & 0 deletions .ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".mypy_cache",
".nox",
".pants.d",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
]

# Same as Black.
line-length = 88
output-format = "grouped"

target-version = "py39"

[lint]
select = [
# bugbear rules
"B",
# remove unused imports
"F401",
# bare except statements
"E722",
# unused arguments
"ARG",
]
ignore = [
"B006",
"B018",
]

unfixable = [
"T201",
"T203",
]
ignore-init-module-imports = true
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Welcome to RAGatouille
3,933 changes: 3,933 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[tool.poetry]
name = "RAGatouille"
version = "0.0.0a"
description = "Library to facilitate the use of late-interaction retrieval models in common RAG contexts."
authors = ["Benjamin Clavie <ben@clavie.eu>"]
readme = "README.md"
packages = [{include = "ragatouille"}]
repository = "https://github.com/bclavie/ragatouille"

[tool.poetry.dependencies]
python = "^3.9"
faiss-cpu = "^1.7.4"
transformers = "^4.36.2"
voyager = "^2.0.2"
sentence-transformers = "^2.2.2"
torch = "^2.1.2"
# Imported directly by ragatouille/data/models.py, so it is declared here
# rather than relied upon as a transitive dependency.
pydantic = ">=1.10,<3.0"
colbert-ir = {git = "https://github.com/stanford-futuredata/ColBERT.git"}
# NOTE: ruff is a development tool and is declared only in the dev group.
# Re-run `poetry lock` after changing this table.

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
mkdocs = "^1.4.3"
mkdocs-material = "^9.1.18"
mkdocstrings = "^0.22.0"
mkdocstrings-python = "^1.1.2"
ruff = "^0.1.9"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Empty file added ragatouille/__init__.py
Empty file.
Empty file added ragatouille/data/__init__.py
Empty file.
21 changes: 21 additions & 0 deletions ragatouille/data/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pydantic import BaseModel


class TrainingTriplet(BaseModel):
    """One (anchor, positive, negative) training example.

    All three members are required plain strings.
    """

    # Reference text the other two members are compared against
    # (presumably the query -- confirm against the training code).
    anchor: str
    # Text labelled as matching the anchor.
    positive: str
    # Text labelled as not matching the anchor.
    negative: str


class QueryPassages(BaseModel):
    """A single query bundled with its positive and negative passages.

    Both passage collections are required lists of strings.
    """

    # The query text.
    query: str
    # Passages labelled as positives for ``query``.
    positive_passages: list[str]
    # Passages labelled as negatives for ``query``.
    negative_passages: list[str]
8 changes: 8 additions & 0 deletions ragatouille/data/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
def make_triplets_from_labels():
    """Make triplets from binary-labelled sentence pairs.

    Placeholder: not implemented yet, takes no arguments and returns ``None``.
    """


def make_triplets_from_lists():
    """Make triplets from per-query lists of positives and negatives.

    Placeholder: not implemented yet, takes no arguments and returns ``None``.
    """
Empty file added ragatouille/models/__init__.py
Empty file.
42 changes: 42 additions & 0 deletions ragatouille/models/colbert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Union
from pathlib import Path
from colbert import Run, ColBERTConfig, Indexer, RunConfig
import torch


class ColBERT:
    """Wrapper tying a ColBERT checkpoint to a run context and a config.

    Creating an instance enters a ``Run().context(...)`` and keeps it open
    until the instance is finalised (see ``__del__``).
    """

    def __init__(
        self,
        pretrained_model_name_or_path: Union[str, Path],
        n_gpu: int = -1,
        **kwargs,
    ):
        """Open the run context and assemble the model configuration.

        Args:
            pretrained_model_name_or_path: Checkpoint identifier handed to
                ``ColBERTConfig.load_from_checkpoint`` and later to ``Indexer``.
            n_gpu: Value used for ``RunConfig(nranks=...)``. ``-1`` means
                "use every CUDA device torch reports, or 1 if there are none".
            **kwargs: Forwarded verbatim to ``ColBERTConfig(...)``; the result
                is combined with the checkpoint's own config through
                ``ColBERTConfig.from_existing``.
        """
        if n_gpu == -1:
            # Never go below one rank, even when no CUDA device is visible.
            n_gpu = max(1, torch.cuda.device_count())

        # NOTE(review): ``root`` is an absolute path directly under the
        # filesystem root; a relative or home-based directory may have been
        # intended -- confirm before changing.
        run_config = RunConfig(
            nranks=n_gpu, experiment="colbert", root="/.ragatouille/"
        )
        self.run_context = Run().context(run_config)
        # Entered by hand because the context must outlive this method;
        # the matching exit happens in ``__del__``.
        self.run_context.__enter__()

        self.checkpoint = pretrained_model_name_or_path
        ckpt_config = ColBERTConfig.load_from_checkpoint(self.checkpoint)
        local_config = ColBERTConfig(**kwargs)
        self.config = ColBERTConfig.from_existing(
            ckpt_config,
            local_config,
        )

    def train(self):
        """Placeholder for training; not implemented yet, returns ``None``."""

    def index(self, name, collection):
        """Index ``collection`` under ``name`` with this instance's checkpoint.

        A new ``Indexer`` is built from ``self.checkpoint`` and ``self.config``
        and stored on ``self.indexer`` before its ``index`` method is called.

        Args:
            name: Passed to ``Indexer.index`` as ``name``.
            collection: Passed to ``Indexer.index`` as ``collection``.
        """
        # Use the checkpoint this instance was constructed with rather than
        # a hard-coded placeholder path.
        self.indexer = Indexer(checkpoint=self.checkpoint, config=self.config)
        self.indexer.index(name=name, collection=collection)

    def search(self, name, query):
        """Placeholder for search; not implemented yet, returns ``None``."""

    def __del__(self):
        """Leave the run context opened in ``__init__``, if there is one."""
        # ``__init__`` may have raised before ``run_context`` was assigned,
        # and finalisation may run more than once; tolerate both.
        run_context = getattr(self, "run_context", None)
        if run_context is not None:
            self.run_context = None
            run_context.__exit__(None, None, None)
Empty file.
19 changes: 19 additions & 0 deletions ragatouille/negative_miners/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Union


class HardNegativeMiner(ABC):
    """Abstract interface that every hard-negative miner must implement.

    Path-like parameters are annotated with ``Union[str, Path]`` rather than
    ``str | Path``: annotations are evaluated when the class body runs, and
    the ``|`` form raises ``TypeError`` on Python 3.9, which this project
    still supports.
    """

    @abstractmethod
    def get_name(self) -> str:
        """Return the miner's name."""
        ...

    @abstractmethod
    def build_index(
        self,
        collection: list[str],
        batch_size: int,
        save_index: bool,
        path: Union[str, Path],
    ) -> Any:
        """Build an index over ``collection``.

        Args:
            collection: The texts to index.
            batch_size: Batch size for the implementation to use.
            save_index: Whether the built index should be saved
                (presumably to ``path`` -- implementation-defined).
            path: Filesystem location associated with the index.

        Returns:
            Implementation-defined.
        """
        ...

    @abstractmethod
    def export_index(self, path: Union[str, Path]) -> bool:
        """Export the index to ``path`` and return a boolean result."""
        ...
Loading

0 comments on commit 4925fd4

Please sign in to comment.