Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add normalize to rerank model #2509

Merged
merged 4 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class RerankRequest(BaseModel):
return_documents: Optional[bool] = False
return_len: Optional[bool] = False
max_chunks_per_doc: Optional[int] = None
kwargs: Optional[str] = None


class TextToImageRequest(BaseModel):
Expand Down Expand Up @@ -1315,11 +1316,6 @@ async def rerank(self, request: Request) -> Response:
payload = await request.json()
body = RerankRequest.parse_obj(payload)
model_uid = body.model
kwargs = {
key: value
for key, value in payload.items()
if key not in RerankRequest.__annotations__.keys()
}

try:
model = await (await self._get_supervisor_ref()).get_model(model_uid)
Expand All @@ -1333,14 +1329,18 @@ async def rerank(self, request: Request) -> Response:
raise HTTPException(status_code=500, detail=str(e))

try:
if body.kwargs is not None:
parsed_kwargs = json.loads(body.kwargs)
else:
parsed_kwargs = {}
scores = await model.rerank(
body.documents,
body.query,
top_n=body.top_n,
max_chunks_per_doc=body.max_chunks_per_doc,
return_documents=body.return_documents,
return_len=body.return_len,
**kwargs,
**parsed_kwargs,
)
return Response(scores, media_type="application/json")
except RuntimeError as re:
Expand Down
1 change: 1 addition & 0 deletions xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def rerank(
"max_chunks_per_doc": max_chunks_per_doc,
"return_documents": return_documents,
"return_len": return_len,
"kwargs": json.dumps(kwargs),
}
request_body.update(kwargs)
response = requests.post(url, json=request_body, headers=self.auth_headers)
Expand Down
15 changes: 11 additions & 4 deletions xinference/model/rerank/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def _auto_detect_type(model_path):
return rerank_type

def load(self):
logger.info("Loading rerank model: %s", self._model_path)
flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
if (
self._auto_detect_type(self._model_path) != "normal"
Expand All @@ -189,6 +190,7 @@ def load(self):
"will force set `use_fp16` to True"
)
self._use_fp16 = True

if self._model_spec.type == "normal":
try:
import sentence_transformers
Expand Down Expand Up @@ -250,22 +252,27 @@ def rerank(
**kwargs,
) -> Rerank:
assert self._model is not None
if kwargs:
raise ValueError("rerank hasn't support extra parameter.")
if max_chunks_per_doc is not None:
raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
sentence_combinations = [[query, doc] for doc in documents]
# reset n tokens
self._model.model.n_tokens = 0
if self._model_spec.type == "normal":
similarity_scores = self._model.predict(
sentence_combinations, convert_to_numpy=False, convert_to_tensor=True
sentence_combinations,
convert_to_numpy=False,
convert_to_tensor=True,
**kwargs,
).cpu()
if similarity_scores.dtype == torch.bfloat16:
similarity_scores = similarity_scores.float()
else:
# Related issue: https://github.com/xorbitsai/inference/issues/1775
similarity_scores = self._model.compute_score(sentence_combinations)
similarity_scores = self._model.compute_score(
sentence_combinations, **kwargs
)

if not isinstance(similarity_scores, Sequence):
similarity_scores = [similarity_scores]
elif (
Expand Down
5 changes: 2 additions & 3 deletions xinference/model/rerank/tests/test_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,8 @@ def test_restful_api(model_name, setup):
kwargs = {
"invalid": "invalid",
}
with pytest.raises(RuntimeError) as err:
scores = model.rerank(corpus, query, **kwargs)
assert "hasn't support" in str(err.value)
with pytest.raises(RuntimeError):
model.rerank(corpus, query, **kwargs)


def test_from_local_uri():
Expand Down
Loading