diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index d9d4678911..d8b4564106 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -115,6 +115,7 @@ class RerankRequest(BaseModel):
     return_documents: Optional[bool] = False
     return_len: Optional[bool] = False
     max_chunks_per_doc: Optional[int] = None
+    kwargs: Optional[str] = None
 
 
 class TextToImageRequest(BaseModel):
@@ -1315,11 +1316,6 @@ async def rerank(self, request: Request) -> Response:
         payload = await request.json()
         body = RerankRequest.parse_obj(payload)
         model_uid = body.model
-        kwargs = {
-            key: value
-            for key, value in payload.items()
-            if key not in RerankRequest.__annotations__.keys()
-        }
 
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -1333,6 +1329,10 @@
             raise HTTPException(status_code=500, detail=str(e))
 
         try:
+            if body.kwargs is not None:
+                parsed_kwargs = json.loads(body.kwargs)
+            else:
+                parsed_kwargs = {}
             scores = await model.rerank(
                 body.documents,
                 body.query,
@@ -1340,7 +1340,7 @@
                 max_chunks_per_doc=body.max_chunks_per_doc,
                 return_documents=body.return_documents,
                 return_len=body.return_len,
-                **kwargs,
+                **parsed_kwargs,
             )
             return Response(scores, media_type="application/json")
         except RuntimeError as re:
diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py
index ed71a7bf05..ab03c566c1 100644
--- a/xinference/client/restful/restful_client.py
+++ b/xinference/client/restful/restful_client.py
@@ -174,6 +174,7 @@ def rerank(
             "max_chunks_per_doc": max_chunks_per_doc,
             "return_documents": return_documents,
             "return_len": return_len,
+            "kwargs": json.dumps(kwargs),
         }
         request_body.update(kwargs)
         response = requests.post(url, json=request_body, headers=self.auth_headers)
diff --git a/xinference/model/rerank/core.py b/xinference/model/rerank/core.py
index ffd1485eb5..4eb071ecd1 100644
--- a/xinference/model/rerank/core.py
+++ b/xinference/model/rerank/core.py
@@ -179,6 +179,7 @@ def _auto_detect_type(model_path):
         return rerank_type
 
     def load(self):
+        logger.info("Loading rerank model: %s", self._model_path)
        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
         if (
             self._auto_detect_type(self._model_path) != "normal"
@@ -189,6 +190,7 @@
                 "will force set `use_fp16` to True"
             )
             self._use_fp16 = True
+
         if self._model_spec.type == "normal":
             try:
                 import sentence_transformers
@@ -250,22 +252,27 @@
         **kwargs,
     ) -> Rerank:
         assert self._model is not None
-        if kwargs:
-            raise ValueError("rerank hasn't support extra parameter.")
         if max_chunks_per_doc is not None:
             raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
+        logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
         sentence_combinations = [[query, doc] for doc in documents]
         # reset n tokens
         self._model.model.n_tokens = 0
         if self._model_spec.type == "normal":
             similarity_scores = self._model.predict(
-                sentence_combinations, convert_to_numpy=False, convert_to_tensor=True
+                sentence_combinations,
+                convert_to_numpy=False,
+                convert_to_tensor=True,
+                **kwargs,
             ).cpu()
             if similarity_scores.dtype == torch.bfloat16:
                 similarity_scores = similarity_scores.float()
         else:
             # Related issue: https://github.com/xorbitsai/inference/issues/1775
-            similarity_scores = self._model.compute_score(sentence_combinations)
+            similarity_scores = self._model.compute_score(
+                sentence_combinations, **kwargs
+            )
+
         if not isinstance(similarity_scores, Sequence):
             similarity_scores = [similarity_scores]
         elif (
diff --git a/xinference/model/rerank/tests/test_rerank.py b/xinference/model/rerank/tests/test_rerank.py
index 4ceac1c811..b76572dbad 100644
--- a/xinference/model/rerank/tests/test_rerank.py
+++ b/xinference/model/rerank/tests/test_rerank.py
@@ -118,9 +118,8 @@ def test_restful_api(model_name, setup):
     kwargs = {
         "invalid": "invalid",
     }
-    with pytest.raises(RuntimeError) as err:
-        scores = model.rerank(corpus, query, **kwargs)
-    assert "hasn't support" in str(err.value)
+    with pytest.raises(RuntimeError):
+        model.rerank(corpus, query, **kwargs)
 
 
 def test_from_local_uri():