ENH: Improve error message (#2738)
codingl2k1 authored Jan 10, 2025
1 parent 4b31c22 commit 1d379a1
Showing 5 changed files with 140 additions and 84 deletions.
110 changes: 45 additions & 65 deletions xinference/api/restful_api.py
@@ -1214,6 +1214,19 @@ async def terminate_model(self, model_uid: str) -> JSONResponse:
async def get_address(self) -> JSONResponse:
return JSONResponse(content=self._supervisor_address)

+ async def _get_model_last_error(self, replica_model_uid: bytes, e: Exception):
+ if not isinstance(e, xo.ServerClosed):
+ return e
+ try:
+ model_status = await (await self._get_supervisor_ref()).get_model_status(
+ replica_model_uid.decode("utf-8")
+ )
+ if model_status is not None and model_status.last_error:
+ return Exception(model_status.last_error)
+ except Exception as ex:
+ return ex
+ return e

async def create_completion(self, request: Request) -> Response:
raw_body = await request.json()
body = CreateCompletionRequest.parse_obj(raw_body)
@@ -1272,6 +1285,7 @@ async def stream_results():
)
return
except Exception as ex:
+ ex = await self._get_model_last_error(model.uid, ex)
logger.exception("Completion stream got an error: %s", ex)
await self._report_error_event(model_uid, str(ex))
# https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
@@ -1286,6 +1300,7 @@
data = await model.generate(body.prompt, kwargs, raw_params=raw_kwargs)
return Response(data, media_type="application/json")
except Exception as e:
+ e = await self._get_model_last_error(model.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
self.handle_request_limit_error(e)
@@ -1317,14 +1332,11 @@ async def create_embedding(self, request: Request) -> Response:
try:
embedding = await model.create_embedding(body.input, **kwargs)
return Response(embedding, media_type="application/json")
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- self.handle_request_limit_error(re)
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def convert_ids_to_tokens(self, request: Request) -> Response:
@@ -1352,14 +1364,11 @@ async def convert_ids_to_tokens(self, request: Request) -> Response:
try:
decoded_texts = await model.convert_ids_to_tokens(body.input, **kwargs)
return Response(decoded_texts, media_type="application/json")
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- self.handle_request_limit_error(re)
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def rerank(self, request: Request) -> Response:
@@ -1393,14 +1402,11 @@ async def rerank(self, request: Request) -> Response:
**parsed_kwargs,
)
return Response(scores, media_type="application/json")
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- self.handle_request_limit_error(re)
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def create_transcriptions(
@@ -1445,13 +1451,11 @@ async def create_transcriptions(
**parsed_kwargs,
)
return Response(content=transcription, media_type="application/json")
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model_ref.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def create_translations(
@@ -1496,13 +1500,11 @@ async def create_translations(
**parsed_kwargs,
)
return Response(content=translation, media_type="application/json")
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model_ref.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def create_speech(
@@ -1558,14 +1560,11 @@ async def stream_results():
)
else:
return Response(media_type="application/octet-stream", content=out)
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- self.handle_request_limit_error(re)
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def get_progress(self, request_id: str) -> JSONResponse:
@@ -1611,14 +1610,11 @@ async def create_images(self, request: Request) -> Response:
logger.error(err_str)
await self._report_error_event(model_uid, err_str)
raise HTTPException(status_code=409, detail=err_str)
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- self.handle_request_limit_error(re)
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def sdapi_options(self, request: Request) -> Response:
@@ -1689,14 +1685,11 @@ async def sdapi_txt2img(self, request: Request) -> Response:
**kwargs,
)
return Response(content=image_list, media_type="application/json")
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- self.handle_request_limit_error(re)
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def sdapi_img2img(self, request: Request) -> Response:
@@ -1723,14 +1716,11 @@ async def sdapi_img2img(self, request: Request) -> Response:
**kwargs,
)
return Response(content=image_list, media_type="application/json")
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- self.handle_request_limit_error(re)
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def create_variations(
@@ -1779,13 +1769,11 @@ async def create_variations(
logger.error(err_str)
await self._report_error_event(model_uid, err_str)
raise HTTPException(status_code=409, detail=err_str)
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model_ref.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def create_inpainting(
@@ -1841,13 +1829,11 @@ async def create_inpainting(
logger.error(err_str)
await self._report_error_event(model_uid, err_str)
raise HTTPException(status_code=409, detail=err_str)
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model_ref.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def create_ocr(
@@ -1887,13 +1873,11 @@ async def create_ocr(
logger.error(err_str)
await self._report_error_event(model_uid, err_str)
raise HTTPException(status_code=409, detail=err_str)
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model_ref.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def create_flexible_infer(self, request: Request) -> Response:
@@ -1920,14 +1904,11 @@ async def create_flexible_infer(self, request: Request) -> Response:
try:
result = await model.infer(**kwargs)
return Response(result, media_type="application/json")
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- self.handle_request_limit_error(re)
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def create_videos(self, request: Request) -> Response:
@@ -1952,14 +1933,11 @@ async def create_videos(self, request: Request) -> Response:
**kwargs,
)
return Response(content=video_list, media_type="application/json")
- except RuntimeError as re:
- logger.error(re, exc_info=True)
- await self._report_error_event(model_uid, str(re))
- self.handle_request_limit_error(re)
- raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
+ e = await self._get_model_last_error(model.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def create_chat_completion(self, request: Request) -> Response:
@@ -2084,6 +2062,7 @@ async def stream_results():
# TODO: Cannot yield here. Yielding here would lead to an error for the next streaming request.
return
except Exception as ex:
+ ex = await self._get_model_last_error(model.uid, ex)
logger.exception("Chat completion stream got an error: %s", ex)
await self._report_error_event(model_uid, str(ex))
# https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
@@ -2102,6 +2081,7 @@
)
return Response(content=data, media_type="application/json")
except Exception as e:
+ e = await self._get_model_last_error(model.uid, e)
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
self.handle_request_limit_error(e)
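Taken together, the restful_api.py changes replace fourteen copy-pasted `except RuntimeError` handlers (each raising HTTP 400) with a single code path: every handler's generic `except Exception` now routes the exception through the new `_get_model_last_error` helper and calls `handle_request_limit_error` before raising the 500. The helper only intervenes when the model actor's server is gone (`xo.ServerClosed`); in that case it asks the supervisor for the model's recorded `last_error`, so the client sees the root cause (for example an out-of-memory crash) instead of a bare connection error. Below is a minimal, self-contained sketch of that lookup; the `FakeSupervisor` and `ModelStatus` stubs are illustrative assumptions, not the actual xinference internals.

```python
# Hedged sketch of the new error path, with stand-in stubs for the
# supervisor and model-status objects.
import asyncio
from dataclasses import dataclass


class ServerClosed(Exception):
    """Stand-in for xoscar's xo.ServerClosed (the actor process is gone)."""


@dataclass
class ModelStatus:
    last_error: str = ""


class FakeSupervisor:
    """Stand-in supervisor; only get_model_status is sketched."""

    async def get_model_status(self, replica_model_uid: str) -> ModelStatus:
        # The real supervisor would return whatever error the model actor
        # recorded before it died, e.g. an out-of-memory message.
        return ModelStatus(last_error="Model actor is out of memory.")


async def get_model_last_error(supervisor, replica_model_uid: bytes, e: Exception):
    # Mirrors RESTfulAPI._get_model_last_error: only ServerClosed triggers
    # the supervisor lookup; any other exception passes through untouched.
    if not isinstance(e, ServerClosed):
        return e
    try:
        status = await supervisor.get_model_status(replica_model_uid.decode("utf-8"))
        if status is not None and status.last_error:
            return Exception(status.last_error)
    except Exception as ex:
        return ex
    return e


async def main():
    raw = ServerClosed("Connection closed")
    resolved = await get_model_last_error(FakeSupervisor(), b"qwen1.5-chat-0", raw)
    print(resolved)  # -> Model actor is out of memory.


asyncio.run(main())
```

Note the `except Exception as ex: return ex` inside the helper: a failed supervisor lookup degrades gracefully, so the caller still gets an exception to log and report, just not an enriched one.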
26 changes: 26 additions & 0 deletions xinference/client/tests/test_client.py
@@ -418,6 +418,13 @@ def set_auto_recover_limit():
del os.environ["XINFERENCE_MODEL_ACTOR_AUTO_RECOVER_LIMIT"]


+ @pytest.fixture
+ def set_test_oom_error():
+ os.environ["XINFERENCE_TEST_OUT_OF_MEMORY_ERROR"] = "1"
+ yield
+ del os.environ["XINFERENCE_TEST_OUT_OF_MEMORY_ERROR"]


@pytest.fixture
def setup_cluster():
import xoscar as xo
@@ -488,3 +495,22 @@ def test_auto_recover(set_auto_recover_limit, setup_cluster):
time.sleep(1)
else:
assert False


+ def test_model_error(set_test_oom_error, setup_cluster):
+ endpoint, _ = setup_cluster
+ client = RESTfulClient(endpoint)
+ 
+ model_uid = client.launch_model(
+ model_name="qwen1.5-chat",
+ model_engine="llama.cpp",
+ model_size_in_billions="0_5",
+ quantization="q4_0",
+ )
+ assert len(client.list_models()) == 1
+ 
+ model = client.get_model(model_uid=model_uid)
+ assert isinstance(model, RESTfulChatModelHandle)
+ 
+ with pytest.raises(RuntimeError, match="Model actor is out of memory"):
+ model.generate("Once upon a time, there was a very old computer")