From e5594a5d09c172e9a6e6870a65a1e6425ad50aaf Mon Sep 17 00:00:00 2001 From: Weizheng Lu Date: Sun, 10 Nov 2024 23:05:45 +0800 Subject: [PATCH 1/7] DOC: Add paper citation (#2533) --- README.md | 18 ++++++++++++++++++ README_ja_JP.md | 18 ++++++++++++++++++ README_zh_CN.md | 18 ++++++++++++++++++ 3 files changed, 54 insertions(+) diff --git a/README.md b/README.md index d63cbb42ff..d705bdb9d7 100644 --- a/README.md +++ b/README.md @@ -180,6 +180,24 @@ Once Xinference is running, there are multiple ways you can try it: via the web | [Slack](https://join.slack.com/t/xorbitsio/shared_invite/zt-1o3z9ucdh-RbfhbPVpx7prOVdM1CAuxg) | Collaborating with other Xorbits users. | | [Twitter](https://twitter.com/xorbitsio) | Staying up-to-date on new features. | +## Citation + +If this work is helpful, please kindly cite as: + +```bibtex +@inproceedings{lu2024xinference, + title = "Xinference: Making Large Model Serving Easy", + author = "Lu, Weizheng and Xiong, Lingfeng and Zhang, Feng and Qin, Xuye and Chen, Yueguo", + booktitle = "Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing: System Demonstrations", + month = nov, + year = "2024", + address = "Miami, Florida, USA", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2024.emnlp-demo.30", + pages = "291--300", +} +``` + ## Contributors diff --git a/README_ja_JP.md b/README_ja_JP.md index c80601a9c7..ff3f0e4861 100644 --- a/README_ja_JP.md +++ b/README_ja_JP.md @@ -104,6 +104,24 @@ Xinferenceが実行されると、Web UI、cURL、コマンドライン、また | [Slack](https://join.slack.com/t/xorbitsio/shared_invite/zt-1o3z9ucdh-RbfhbPVpx7prOVdM1CAuxg) | 他のXorbitsユーザーとの協力。 | | [Twitter](https://twitter.com/xorbitsio) | 新機能に関する最新情報の入手。 | +## 引用 + +この仕事が役立つ場合は、以下のように引用してください: + +```bibtex +@inproceedings{lu2024xinference, + title = "Xinference: Making Large Model Serving Easy", + author = "Lu, Weizheng and Xiong, Lingfeng and Zhang, Feng and Qin, Xuye and Chen, Yueguo", + booktitle = "Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing: System Demonstrations", + month = nov, + year = "2024", + address = "Miami, Florida, USA", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2024.emnlp-demo.30", + pages = "291--300", +} +``` + ## 寄稿者 diff --git a/README_zh_CN.md b/README_zh_CN.md index 2df28e2632..b75bb6e905 100644 --- a/README_zh_CN.md +++ b/README_zh_CN.md @@ -164,6 +164,24 @@ $ xinference-local | [微信社群](https://xorbits.cn/assets/images/wechat_work_qr.png) | 与其他 Xorbits 用户交流。 | | [知乎](https://zhihu.com/org/xorbits) | 了解团队最新的进展。 | +## 引用 + +如果您觉得此项目有帮助,请以如下格式引用我们: + +```bibtex +@inproceedings{lu2024xinference, + title = "Xinference: Making Large Model Serving Easy", + author = "Lu, Weizheng and Xiong, Lingfeng and Zhang, Feng and Qin, Xuye and Chen, Yueguo", + booktitle = "Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing: System Demonstrations", + month = nov, + year = "2024", + address = "Miami, Florida, USA", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2024.emnlp-demo.30", + pages = "291--300", +} +``` + ## 贡献者 From fe945526cb3585e764eaf0db83884d50b95754d7 Mon Sep 17 00:00:00 2001 From: codingl2k1 <138426806+codingl2k1@users.noreply.github.com> Date: Tue, 12 Nov 2024 03:38:22 +0100 Subject: [PATCH 2/7] FEAT: Basic cancel support for image model (#2528) --- xinference/api/restful_api.py | 61 ++++++++++++++-- 
xinference/client/restful/restful_client.py | 9 ++- xinference/constants.py | 1 + xinference/core/model.py | 13 +++- xinference/core/supervisor.py | 10 ++- xinference/core/utils.py | 69 ++++++++++++++++++- .../image/tests/test_stable_diffusion.py | 58 ++++++++++++++++ 7 files changed, 207 insertions(+), 14 deletions(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index ed3a2eab90..d9d4678911 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -52,10 +52,14 @@ from .._compat import BaseModel, Field from .._version import get_versions -from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT, XINFERENCE_DISABLE_METRICS +from ..constants import ( + XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION, + XINFERENCE_DEFAULT_ENDPOINT_PORT, + XINFERENCE_DISABLE_METRICS, +) from ..core.event import Event, EventCollectorActor, EventType from ..core.supervisor import SupervisorActor -from ..core.utils import json_dumps +from ..core.utils import CancelMixin, json_dumps from ..types import ( ChatCompletion, Completion, @@ -206,7 +210,7 @@ class BuildGradioImageInterfaceRequest(BaseModel): model_ability: List[str] -class RESTfulAPI: +class RESTfulAPI(CancelMixin): def __init__( self, supervisor_address: str, @@ -1531,8 +1535,11 @@ async def create_images(self, request: Request) -> Response: await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) + request_id = None try: kwargs = json.loads(body.kwargs) if body.kwargs else {} + request_id = kwargs.get("request_id") + self._add_running_task(request_id) image_list = await model.text_to_image( prompt=body.prompt, n=body.n, @@ -1541,6 +1548,11 @@ async def create_images(self, request: Request) -> Response: **kwargs, ) return Response(content=image_list, media_type="application/json") + except asyncio.CancelledError: + err_str = f"The request has been cancelled: {request_id}" + logger.error(err_str) + await self._report_error_event(model_uid, err_str) + raise HTTPException(status_code=409, detail=err_str) except RuntimeError as re: logger.error(re, exc_info=True) await self._report_error_event(model_uid, str(re)) @@ -1686,11 +1698,14 @@ async def create_variations( await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) + request_id = None try: if kwargs is not None: parsed_kwargs = json.loads(kwargs) else: parsed_kwargs = {} + request_id = parsed_kwargs.get("request_id") + self._add_running_task(request_id) image_list = await model_ref.image_to_image( image=Image.open(image.file), prompt=prompt, @@ -1701,6 +1716,11 @@ async def create_variations( **parsed_kwargs, ) return Response(content=image_list, media_type="application/json") + except asyncio.CancelledError: + err_str = f"The request has been cancelled: {request_id}" + logger.error(err_str) + await self._report_error_event(model_uid, err_str) + raise HTTPException(status_code=409, detail=err_str) except RuntimeError as re: logger.error(re, exc_info=True) await self._report_error_event(model_uid, str(re)) @@ -1734,11 +1754,14 @@ async def create_inpainting( await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) + request_id = None try: if kwargs is not None: parsed_kwargs = json.loads(kwargs) else: parsed_kwargs = {} + request_id = parsed_kwargs.get("request_id") + self._add_running_task(request_id) im = Image.open(image.file) mask_im = Image.open(mask_image.file) if not size: @@ -1755,6 +1778,11 @@ async def 
create_inpainting( **parsed_kwargs, ) return Response(content=image_list, media_type="application/json") + except asyncio.CancelledError: + err_str = f"The request has been cancelled: {request_id}" + logger.error(err_str) + await self._report_error_event(model_uid, err_str) + raise HTTPException(status_code=409, detail=err_str) except RuntimeError as re: logger.error(re, exc_info=True) await self._report_error_event(model_uid, str(re)) @@ -1782,17 +1810,25 @@ async def create_ocr( await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) + request_id = None try: if kwargs is not None: parsed_kwargs = json.loads(kwargs) else: parsed_kwargs = {} + request_id = parsed_kwargs.get("request_id") + self._add_running_task(request_id) im = Image.open(image.file) text = await model_ref.ocr( image=im, **parsed_kwargs, ) return Response(content=text, media_type="text/plain") + except asyncio.CancelledError: + err_str = f"The request has been cancelled: {request_id}" + logger.error(err_str) + await self._report_error_event(model_uid, err_str) + raise HTTPException(status_code=409, detail=err_str) except RuntimeError as re: logger.error(re, exc_info=True) await self._report_error_event(model_uid, str(re)) @@ -2111,10 +2147,25 @@ async def get_model_events(self, model_uid: str) -> JSONResponse: logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) - async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse: + async def abort_request( + self, request: Request, model_uid: str, request_id: str + ) -> JSONResponse: try: + payload = await request.json() + block_duration = payload.get( + "block_duration", XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION + ) + logger.info( + "Abort request with model uid: %s, request id: %s, block duration: %s", + model_uid, + request_id, + block_duration, + ) supervisor_ref = await self._get_supervisor_ref() - res = await supervisor_ref.abort_request(model_uid, request_id) + res = await supervisor_ref.abort_request( + model_uid, request_id, block_duration + ) + self._cancel_running_task(request_id, block_duration) return JSONResponse(content=res) except Exception as e: logger.error(e, exc_info=True) diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index dd5e3f1146..ed71a7bf05 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -1357,7 +1357,7 @@ def query_engine_by_model_name(self, model_name: str): response_data = response.json() return response_data - def abort_request(self, model_uid: str, request_id: str): + def abort_request(self, model_uid: str, request_id: str, block_duration: int = 30): """ Abort a request. Abort a submitted request. If the request is finished or not found, this method will be a no-op. @@ -1369,13 +1369,18 @@ def abort_request(self, model_uid: str, request_id: str): Model uid. request_id: str Request id. + block_duration: int + The duration to make the request id abort. If set to 0, the abort_request will be immediate, which may + prevent it from taking effect if it arrives before the request operation. Returns ------- Dict Return empty dict. 
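
# Usage sketch for the request-cancel support added in this patch (illustration
# only, a minimal sketch rather than the project's documented API surface).
# Assumptions: a local Xinference endpoint at http://127.0.0.1:9997 and an image
# model such as "sd-turbo" launchable on this machine; the calls mirror
# test_restful_api_abort later in this patch.
import threading
import time
import uuid

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model_uid = client.launch_model(
    model_uid="my_image_model",  # hypothetical uid, as in the test below
    model_name="sd-turbo",
    model_type="image",
)
model = client.get_model(model_uid)
request_id = str(uuid.uuid4())

def abort_later():
    # Let the generation start, then abort it. The default block_duration (30s)
    # keeps the request id blocked even if the abort arrives before the request.
    time.sleep(1)
    client.abort_request(model_uid, request_id)

threading.Thread(target=abort_later).start()
try:
    model.text_to_image(
        prompt="A cinematic shot of a baby raccoon wearing a priest robe.",
        size="512*512",
        num_inference_steps=10,
        request_id=request_id,
    )
except RuntimeError as err:
    print(err)  # e.g. "The request has been cancelled: <request_id>"
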
""" url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort" - response = requests.post(url, headers=self._headers) + response = requests.post( + url, headers=self._headers, json={"block_duration": block_duration} + ) if response.status_code != 200: raise RuntimeError( f"Failed to abort request, detail: {_get_error_string(response)}" diff --git a/xinference/constants.py b/xinference/constants.py index 93e73e4d58..dd0adcb864 100644 --- a/xinference/constants.py +++ b/xinference/constants.py @@ -88,3 +88,4 @@ def get_xinference_home() -> str: XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE, None ) XINFERENCE_LAUNCH_MODEL_RETRY = 3 +XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION = 30 diff --git a/xinference/core/model.py b/xinference/core/model.py index e911c71e6d..42453ddc69 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -41,6 +41,7 @@ import xoscar as xo from ..constants import ( + XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION, XINFERENCE_LAUNCH_MODEL_RETRY, XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE, ) @@ -57,7 +58,7 @@ logger = logging.getLogger(__name__) from ..device_utils import empty_cache -from .utils import json_dumps, log_async +from .utils import CancelMixin, json_dumps, log_async try: from torch.cuda import OutOfMemoryError @@ -136,7 +137,7 @@ async def _async_wrapper(*args, **kwargs): return _wrapper -class ModelActor(xo.StatelessActor): +class ModelActor(xo.StatelessActor, CancelMixin): _replica_model_uid: Optional[str] @classmethod @@ -553,6 +554,7 @@ async def _call_wrapper_binary(self, fn: Callable, *args, **kwargs): @oom_check async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs): + self._add_running_task(kwargs.get("request_id")) if self._lock is None: if inspect.iscoroutinefunction(fn): ret = await fn(*args, **kwargs) @@ -761,9 +763,14 @@ async def chat(self, messages: List[Dict], *args, **kwargs): prompt_tokens, ) - async def abort_request(self, request_id: str) -> str: + async def abort_request( + self, + request_id: str, + block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION, + ) -> str: from .utils import AbortRequestMessage + self._cancel_running_task(request_id, block_duration) if self.allow_batching(): if self._scheduler_ref is None: return AbortRequestMessage.NOT_FOUND.name diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 8f705217a3..c8f2f59ff6 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -35,6 +35,7 @@ import xoscar as xo from ..constants import ( + XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION, XINFERENCE_DISABLE_HEALTH_CHECK, XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD, XINFERENCE_HEALTH_CHECK_INTERVAL, @@ -1213,7 +1214,12 @@ async def list_cached_models( return cached_models @log_async(logger=logger) - async def abort_request(self, model_uid: str, request_id: str) -> Dict: + async def abort_request( + self, + model_uid: str, + request_id: str, + block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION, + ) -> Dict: from .scheduler import AbortRequestMessage res = {"msg": AbortRequestMessage.NO_OP.name} @@ -1228,7 +1234,7 @@ async def abort_request(self, model_uid: str, request_id: str) -> Dict: if worker_ref is None: continue model_ref = await worker_ref.get_model(model_uid=rep_mid) - result_info = await model_ref.abort_request(request_id) + result_info = await model_ref.abort_request(request_id, block_duration) res["msg"] = result_info if result_info == AbortRequestMessage.DONE.name: break diff --git a/xinference/core/utils.py 
b/xinference/core/utils.py index d4caba8c54..278c570b20 100644 --- a/xinference/core/utils.py +++ b/xinference/core/utils.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import logging import os import random import string import uuid +import weakref from enum import Enum from typing import Dict, Generator, List, Optional, Tuple, Union @@ -23,7 +25,10 @@ from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown from .._compat import BaseModel -from ..constants import XINFERENCE_LOG_ARG_MAX_LENGTH +from ..constants import ( + XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION, + XINFERENCE_LOG_ARG_MAX_LENGTH, +) logger = logging.getLogger(__name__) @@ -49,13 +54,20 @@ def log_async( ): import time from functools import wraps + from inspect import signature def decorator(func): func_name = func.__name__ + sig = signature(func) @wraps(func) async def wrapped(*args, **kwargs): - request_id_str = kwargs.get("request_id", "") + try: + bound_args = sig.bind_partial(*args, **kwargs) + arguments = bound_args.arguments + except TypeError: + arguments = {} + request_id_str = arguments.get("request_id", "") if not request_id_str: request_id_str = uuid.uuid1() if func_name == "text_to_image": @@ -269,3 +281,56 @@ def assign_replica_gpu( if isinstance(gpu_idx, list) and gpu_idx: return gpu_idx[rep_id::replica] return gpu_idx + + +class CancelMixin: + _CANCEL_TASK_NAME = "abort_block" + + def __init__(self): + self._running_tasks: weakref.WeakValueDictionary[ + str, asyncio.Task + ] = weakref.WeakValueDictionary() + + def _add_running_task(self, request_id: Optional[str]): + """Add current asyncio task to the running task. + :param request_id: The corresponding request id. + """ + if request_id is None: + return + running_task = self._running_tasks.get(request_id) + if running_task is not None: + if running_task.get_name() == self._CANCEL_TASK_NAME: + raise Exception(f"The request has been aborted: {request_id}") + raise Exception(f"Duplicate request id: {request_id}") + current_task = asyncio.current_task() + assert current_task is not None + self._running_tasks[request_id] = current_task + + def _cancel_running_task( + self, + request_id: Optional[str], + block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION, + ): + """Cancel the running asyncio task. + :param request_id: The request id to cancel. + :param block_duration: The duration seconds to ensure the request can't be executed. 
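
# Minimal sketch of how this mixin is meant to be used (illustration only;
# assumes CancelMixin is importable from xinference.core.utils as defined in
# this patch, and nothing beyond the standard asyncio module).
import asyncio

from xinference.core.utils import CancelMixin

class DummyHandler(CancelMixin):
    async def handle(self, request_id: str) -> str:
        self._add_running_task(request_id)  # register the current asyncio task
        await asyncio.sleep(10)             # stand-in for slow model work
        return "done"

async def main():
    handler = DummyHandler()
    task = asyncio.create_task(handler.handle("req-1"))
    await asyncio.sleep(0.1)  # let handle() start running
    # Cancel the registered task; keep "req-1" blocked for one more second so a
    # late-arriving duplicate of the same request cannot slip through.
    handler._cancel_running_task("req-1", block_duration=1)
    try:
        await task
    except asyncio.CancelledError:
        print("req-1 was cancelled")

asyncio.run(main())
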
+ """ + if request_id is None: + return + running_task = self._running_tasks.pop(request_id, None) + if running_task is not None: + running_task.cancel() + + async def block_task(): + """This task is for blocking the request for a duration.""" + try: + await asyncio.sleep(block_duration) + logger.info("Abort block end for request: %s", request_id) + except asyncio.CancelledError: + logger.info("Abort block is cancelled for request: %s", request_id) + + if block_duration > 0: + logger.info("Abort block start for request: %s", request_id) + self._running_tasks[request_id] = asyncio.create_task( + block_task(), name=self._CANCEL_TASK_NAME + ) diff --git a/xinference/model/image/tests/test_stable_diffusion.py b/xinference/model/image/tests/test_stable_diffusion.py index e4da8014d0..04cb607201 100644 --- a/xinference/model/image/tests/test_stable_diffusion.py +++ b/xinference/model/image/tests/test_stable_diffusion.py @@ -18,6 +18,8 @@ import os.path import shutil import tempfile +import threading +import time import uuid from io import BytesIO @@ -195,6 +197,62 @@ def test_restful_api_for_image_with_mlsd_controlnet(setup): logger.info("test result %s", r) +@pytest.mark.parametrize("model_name", ["sd-turbo"]) +def test_restful_api_abort(setup, model_name): + endpoint, _ = setup + from ....client import Client + + client = Client(endpoint) + + model_uid = client.launch_model( + model_uid="my_controlnet", + model_name=model_name, + model_type="image", + ) + model = client.get_model(model_uid) + + request_id = str(uuid.uuid4()) + client.abort_request(model_uid, request_id, 1) + time.sleep(2) + r = model.text_to_image( + prompt="A cinematic shot of a baby raccoon wearing an intricate italian priest robe.", + size="512*512", + num_inference_steps=10, + request_id=request_id, + ) + assert "created" in r + + request_id = str(uuid.uuid4()) + client.abort_request(model_uid, request_id) + with pytest.raises( + RuntimeError, match=f"The request has been aborted: {request_id}" + ): + model.text_to_image( + prompt="A cinematic shot of a baby raccoon wearing an intricate italian priest robe.", + size="512*512", + num_inference_steps=10, + request_id=request_id, + ) + + request_id = str(uuid.uuid4()) + + def _abort(): + time.sleep(1) + client.abort_request(model_uid, request_id) + + t = threading.Thread(target=_abort) + t.start() + with pytest.raises( + RuntimeError, match=f"The request has been cancelled: {request_id}" + ): + model.text_to_image( + prompt="A cinematic shot of a baby raccoon wearing an intricate italian priest robe.", + size="512*512", + num_inference_steps=10, + request_id=request_id, + ) + + @pytest.mark.parametrize("model_name", ["sd-turbo", "sdxl-turbo"]) def test_restful_api_for_sd_turbo(setup, model_name): if model_name == "sdxl-turbo": From 38728b66c4888e0beda58781da7076593fbb12ac Mon Sep 17 00:00:00 2001 From: Adam Ning Date: Wed, 13 Nov 2024 13:22:59 +0800 Subject: [PATCH 3/7] FEAT: Add qwen2.5-coder 0.5B 1.5B 3B 14B 32B (#2543) --- xinference/model/llm/llm_family.json | 180 ++++++++++++++- .../model/llm/llm_family_modelscope.json | 211 ++++++++++++++++++ 2 files changed, 387 insertions(+), 4 deletions(-) diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index d8c2f4aa29..472f35b6eb 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -8205,6 +8205,16 @@ ], "model_description": "Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models (formerly known as CodeQwen).", "model_specs": 
[ + { + "model_format": "pytorch", + "model_size_in_billions": "0_5", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "Qwen/Qwen2.5-Coder-0.5B" + }, { "model_format": "pytorch", "model_size_in_billions": "1_5", @@ -8213,8 +8223,17 @@ "8-bit", "none" ], - "model_id": "Qwen/Qwen2.5-Coder-1.5B", - "model_revision": "d3586cfe793730945f8e4d7ef31032a3ee50247d" + "model_id": "Qwen/Qwen2.5-Coder-1.5B" + }, + { + "model_format": "pytorch", + "model_size_in_billions": "3", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "Qwen/Qwen2.5-Coder-3B" }, { "model_format": "pytorch", @@ -8224,8 +8243,27 @@ "8-bit", "none" ], - "model_id": "Qwen/Qwen2.5-Coder-7B", - "model_revision": "30b6a7e874a78d46b80fa1db3194ea427dd41b08" + "model_id": "Qwen/Qwen2.5-Coder-7B" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 14, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "Qwen/Qwen2.5-Coder-14B" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 32, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "Qwen/Qwen2.5-Coder-32B" } ] }, @@ -8243,6 +8281,16 @@ ], "model_description": "Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models (formerly known as CodeQwen).", "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": "0_5", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "Qwen/Qwen2.5-Coder-0.5B-Instruct" + }, { "model_format": "pytorch", "model_size_in_billions": "1_5", @@ -8253,6 +8301,16 @@ ], "model_id": "Qwen/Qwen2.5-Coder-1.5B-Instruct" }, + { + "model_format": "pytorch", + "model_size_in_billions": "3", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "Qwen/Qwen2.5-Coder-3B-Instruct" + }, { "model_format": "pytorch", "model_size_in_billions": 7, @@ -8263,6 +8321,53 @@ ], "model_id": "Qwen/Qwen2.5-Coder-7B-Instruct" }, + { + "model_format": "pytorch", + "model_size_in_billions": 14, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "Qwen/Qwen2.5-Coder-14B-Instruct" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 32, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "Qwen/Qwen2.5-Coder-32B-Instruct" + }, + { + "model_format": "gptq", + "model_size_in_billions": "0_5", + "quantizations": [ + "Int4", + "Int8" + ], + "model_id": "Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-{quantization}" + }, + { + "model_format": "gptq", + "model_size_in_billions": "1_5", + "quantizations": [ + "Int4", + "Int8" + ], + "model_id": "Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-{quantization}" + }, + { + "model_format": "gptq", + "model_size_in_billions": "3", + "quantizations": [ + "Int4", + "Int8" + ], + "model_id": "Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-{quantization}" + }, { "model_format": "gptq", "model_size_in_billions": "7", @@ -8272,6 +8377,73 @@ ], "model_id": "Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-{quantization}" }, + { + "model_format": "gptq", + "model_size_in_billions": "14", + "quantizations": [ + "Int4", + "Int8" + ], + "model_id": "Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-{quantization}" + }, + { + "model_format": "gptq", + "model_size_in_billions": "32", + "quantizations": [ + "Int4", + "Int8" + ], + "model_id": "Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-{quantization}" + }, + { + "model_format": "awq", + "model_size_in_billions": "0_5", + "quantizations": [ + "Int4" + ], + "model_id": "Qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ" + }, + { + "model_format": "awq", + 
"model_size_in_billions": "1_5", + "quantizations": [ + "Int4" + ], + "model_id": "Qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ" + }, + { + "model_format": "awq", + "model_size_in_billions": "3", + "quantizations": [ + "Int4" + ], + "model_id": "Qwen/Qwen2.5-Coder-3B-Instruct-AWQ" + }, + { + "model_format": "awq", + "model_size_in_billions": "7", + "quantizations": [ + "Int4" + ], + "model_id": "Qwen/Qwen2.5-Coder-7B-Instruct-AWQ" + }, + { + "model_format": "awq", + "model_size_in_billions": "14", + "quantizations": [ + "Int4" + ], + "model_id": "Qwen/Qwen2.5-Coder-14B-Instruct-AWQ" + }, + { + "model_format": "awq", + "model_size_in_billions": "32", + "quantizations": [ + "Int4" + ], + "model_id": "Qwen/Qwen2.5-Coder-32B-Instruct-AWQ" + }, + { "model_format": "ggufv2", "model_size_in_billions": "1_5", diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index bfd3d09a4a..f8598d3602 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -5907,6 +5907,18 @@ ], "model_description": "Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models (formerly known as CodeQwen).", "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": "0_5", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "qwen/Qwen2.5-Coder-0.5B", + "model_revision": "master", + "model_hub": "modelscope" + }, { "model_format": "pytorch", "model_size_in_billions": "1_5", @@ -5919,6 +5931,18 @@ "model_revision": "master", "model_hub": "modelscope" }, + { + "model_format": "pytorch", + "model_size_in_billions": "3", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "qwen/Qwen2.5-Coder-3B", + "model_revision": "master", + "model_hub": "modelscope" + }, { "model_format": "pytorch", "model_size_in_billions": 7, @@ -5930,6 +5954,30 @@ "model_id": "qwen/Qwen2.5-Coder-7B", "model_revision": "master", "model_hub": "modelscope" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 14, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "qwen/Qwen2.5-Coder-14B", + "model_revision": "master", + "model_hub": "modelscope" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 32, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "qwen/Qwen2.5-Coder-32B", + "model_revision": "master", + "model_hub": "modelscope" } ] }, @@ -5947,6 +5995,18 @@ ], "model_description": "Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models (formerly known as CodeQwen).", "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": "0_5", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "qwen/Qwen2.5-Coder-0.5B-Instruct", + "model_revision": "master", + "model_hub": "modelscope" + }, { "model_format": "pytorch", "model_size_in_billions": "1_5", @@ -5958,6 +6018,17 @@ "model_id": "qwen/Qwen2.5-Coder-1.5B-Instruct", "model_revision": "master", "model_hub": "modelscope" + }, { + "model_format": "pytorch", + "model_size_in_billions": "3", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "qwen/Qwen2.5-Coder-3B-Instruct", + "model_revision": "master", + "model_hub": "modelscope" }, { "model_format": "pytorch", @@ -5971,6 +6042,63 @@ "model_revision": "master", "model_hub": "modelscope" }, + { + "model_format": "pytorch", + "model_size_in_billions": 14, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": 
"qwen/Qwen2.5-Coder-14B-Instruct", + "model_revision": "master", + "model_hub": "modelscope" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 32, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "qwen/Qwen2.5-Coder-32B-Instruct", + "model_revision": "master", + "model_hub": "modelscope" + }, + { + "model_format": "gptq", + "model_size_in_billions": "0_5", + "quantizations": [ + "Int4", + "Int8" + ], + "model_id": "qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-{quantization}", + "model_revision": "master", + "model_hub": "modelscope" + }, + { + "model_format": "gptq", + "model_size_in_billions": "1_5", + "quantizations": [ + "Int4", + "Int8" + ], + "model_id": "qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-{quantization}", + "model_revision": "master", + "model_hub": "modelscope" + }, + { + "model_format": "gptq", + "model_size_in_billions": 3, + "quantizations": [ + "Int4", + "Int8" + ], + "model_id": "qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-{quantization}", + "model_revision": "master", + "model_hub": "modelscope" + }, { "model_format": "gptq", "model_size_in_billions": 7, @@ -5982,6 +6110,89 @@ "model_revision": "master", "model_hub": "modelscope" }, + { + "model_format": "gptq", + "model_size_in_billions": 14, + "quantizations": [ + "Int4", + "Int8" + ], + "model_id": "qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-{quantization}", + "model_revision": "master", + "model_hub": "modelscope" + }, + { + "model_format": "gptq", + "model_size_in_billions": 32, + "quantizations": [ + "Int4", + "Int8" + ], + "model_id": "qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-{quantization}", + "model_revision": "master", + "model_hub": "modelscope" + }, + { + "model_format": "awq", + "model_size_in_billions": "0_5", + "quantizations": [ + "Int4" + ], + "model_id": "qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ", + "model_revision": "master", + "model_hub": "modelscope" + }, + { + "model_format": "awq", + "model_size_in_billions": "1_5", + "quantizations": [ + "Int4" + ], + "model_id": "qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ", + "model_revision": "master", + "model_hub": "modelscope" + }, + { + "model_format": "awq", + "model_size_in_billions": 3, + "quantizations": [ + "Int4" + ], + "model_id": "qwen/Qwen2.5-Coder-3B-Instruct-AWQ", + "model_revision": "master", + "model_hub": "modelscope" + }, + { + "model_format": "awq", + "model_size_in_billions": 7, + "quantizations": [ + "Int4" + ], + "model_id": "qwen/Qwen2.5-Coder-7B-Instruct-AWQ", + "model_revision": "master", + "model_hub": "modelscope" + }, + { + "model_format": "awq", + "model_size_in_billions": 14, + "quantizations": [ + "Int4" + ], + "model_id": "qwen/Qwen2.5-Coder-14B-Instruct-AWQ", + "model_revision": "master", + "model_hub": "modelscope" + }, + { + "model_format": "awq", + "model_size_in_billions": 32, + "quantizations": [ + "Int4" + ], + "model_id": "qwen/Qwen2.5-Coder-32B-Instruct-AWQ", + "model_revision": "master", + "model_hub": "modelscope" + }, + { "model_format": "ggufv2", "model_size_in_billions": "1_5", From ef6ce75aa1f23c268bb2d148712766381c74a87e Mon Sep 17 00:00:00 2001 From: Xuye Qin Date: Thu, 14 Nov 2024 11:01:12 +0800 Subject: [PATCH 4/7] FEAT: support kvcache in multi-round chat for MLX (#2534) --- xinference/model/llm/mlx/core.py | 47 +++++++++++++++++++++- xinference/model/llm/mlx/tests/test_mlx.py | 6 +++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index d01324fbf5..0ed55b2fc8 100644 --- a/xinference/model/llm/mlx/core.py +++ 
b/xinference/model/llm/mlx/core.py @@ -17,7 +17,8 @@ import sys import time import uuid -from typing import Dict, Iterator, List, Optional, TypedDict, Union +from dataclasses import dataclass, field +from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict, Union from ....fields import max_tokens_field from ....types import ( @@ -53,6 +54,14 @@ class MLXGenerateConfig(TypedDict, total=False): stream: bool stream_options: Optional[Union[dict, None]] tools: Optional[List[Dict]] + lora_name: Optional[str] + + +@dataclass +class PromptCache: + cache: List[Any] = field(default_factory=list) + model_key: Tuple[str, Optional[str]] = ("", None) + tokens: List[int] = field(default_factory=list) class MLXModel(LLM): @@ -69,6 +78,8 @@ def __init__( super().__init__(model_uid, model_family, model_spec, quantization, model_path) self._use_fast_tokenizer = True self._model_config: MLXModelConfig = self._sanitize_model_config(model_config) + self._max_kv_size = None + self._prompt_cache = None if peft_model is not None: raise ValueError("MLX engine has not supported lora yet") @@ -127,6 +138,9 @@ def _load_model(self, **kwargs): logger.debug(f"Setting cache limit to {cache_limit_gb} GB") mx.metal.set_cache_limit(cache_limit_gb * 1024 * 1024 * 1024) + self._max_kv_size = kwargs.get("max_kv_size", None) + self._prompt_cache = PromptCache() + return load( self.model_path, tokenizer_config=tokenizer_config, @@ -156,6 +170,27 @@ def match( return False return True + def _get_prompt_cache(self, prompt, lora_name: Optional[str] = None): + from mlx_lm.models.cache import make_prompt_cache + + assert self._prompt_cache is not None + cache_len = len(self._prompt_cache.tokens) + model_key = (self.model_path, lora_name) + if ( + self._prompt_cache.model_key != model_key + or cache_len >= len(prompt) + or self._prompt_cache.tokens != prompt[:cache_len] + ): + self._prompt_cache.model_key = model_key + self._prompt_cache.cache = make_prompt_cache(self._model, self._max_kv_size) + self._prompt_cache.tokens = [] + logger.debug("Making new prompt cache for %s", self.model_uid) + else: + prompt = prompt[cache_len:] + logger.debug("Cache hit for %s", self.model_uid) + self._prompt_cache.tokens.extend(prompt) + return prompt + def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig): import mlx.core as mx from mlx_lm.utils import generate_step @@ -167,6 +202,7 @@ def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig): chunk_id = str(uuid.uuid4()) stop_token_ids = kwargs.get("stop_token_ids", []) stream = kwargs.get("stream", False) + lora_name = kwargs.get("lora_name") stream_options = kwargs.pop("stream_options", None) include_usage = ( stream_options["include_usage"] @@ -174,12 +210,15 @@ def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig): else False ) - prompt_tokens = mx.array(tokenizer.encode(prompt)) + prompt_token_ids = tokenizer.encode(prompt) + prompt_token_ids = self._get_prompt_cache(prompt_token_ids, lora_name) + prompt_tokens = mx.array(prompt_token_ids) input_echo_len = len(prompt_tokens) i = 0 start = time.time() output = "" + tokens = [] for (token, _), i in zip( generate_step( prompt_tokens, @@ -189,9 +228,11 @@ def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig): repetition_context_size=kwargs["repetition_context_size"], top_p=kwargs["top_p"], logit_bias=kwargs["logit_bias"], + prompt_cache=self._prompt_cache.cache, # type: ignore ), range(max_tokens), ): + tokens.append(token) if token == tokenizer.eos_token_id or token in 
stop_token_ids: # type: ignore break @@ -230,6 +271,8 @@ def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig): f"Average generation speed: {i / (time.time() - start):.2f} tokens/s." ) + self._prompt_cache.tokens.extend(tokens) # type: ignore + if i == max_tokens - 1: finish_reason = "length" else: diff --git a/xinference/model/llm/mlx/tests/test_mlx.py b/xinference/model/llm/mlx/tests/test_mlx.py index b1d0682e5b..2807dabb04 100644 --- a/xinference/model/llm/mlx/tests/test_mlx.py +++ b/xinference/model/llm/mlx/tests/test_mlx.py @@ -40,3 +40,9 @@ def test_load_mlx(setup): completion = model.chat(messages) assert "content" in completion["choices"][0]["message"] assert len(completion["choices"][0]["message"]["content"]) != 0 + content = completion["choices"][0]["message"]["content"] + messages.append({"role": "assistant", "content": content}) + messages.append({"role": "user", "content": "explain it"}) + completion = model.chat(messages) + assert "content" in completion["choices"][0]["message"] + assert len(completion["choices"][0]["message"]["content"]) != 0 From 042eb5baab0515a59b201677d7d20c9ed842fdb9 Mon Sep 17 00:00:00 2001 From: Xuye Qin Date: Thu, 14 Nov 2024 11:01:38 +0800 Subject: [PATCH 5/7] BUG: fix variant error for image model (#2547) --- xinference/model/image/stable_diffusion/core.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/xinference/model/image/stable_diffusion/core.py b/xinference/model/image/stable_diffusion/core.py index c5a9b33f86..e0f7e5c886 100644 --- a/xinference/model/image/stable_diffusion/core.py +++ b/xinference/model/image/stable_diffusion/core.py @@ -17,9 +17,11 @@ import inspect import itertools import logging +import os import re import sys import warnings +from glob import glob from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import PIL.Image @@ -194,8 +196,9 @@ def load(self): if sys.platform != "darwin" and torch_dtype is None: # The following params crashes on Mac M2 self._torch_dtype = self._kwargs["torch_dtype"] = torch.float16 - self._kwargs["variant"] = "fp16" - self._kwargs["use_safetensors"] = True + self._kwargs["use_safetensors"] = any( + glob(os.path.join(self._model_path, "*/*.safetensors")) + ) if isinstance(torch_dtype, str): self._kwargs["torch_dtype"] = getattr(torch, torch_dtype) From 7a0bb6035d5546c9eab16fad1adb3fe27b10c757 Mon Sep 17 00:00:00 2001 From: Bryan Date: Fri, 15 Nov 2024 13:14:45 +0800 Subject: [PATCH 6/7] ENH: add normalize to rerank model (#2509) Co-authored-by: libing Co-authored-by: codingl2k1 --- xinference/api/restful_api.py | 12 ++++++------ xinference/client/restful/restful_client.py | 1 + xinference/model/rerank/core.py | 15 +++++++++++---- xinference/model/rerank/tests/test_rerank.py | 5 ++--- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index d9d4678911..d8b4564106 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -115,6 +115,7 @@ class RerankRequest(BaseModel): return_documents: Optional[bool] = False return_len: Optional[bool] = False max_chunks_per_doc: Optional[int] = None + kwargs: Optional[str] = None class TextToImageRequest(BaseModel): @@ -1315,11 +1316,6 @@ async def rerank(self, request: Request) -> Response: payload = await request.json() body = RerankRequest.parse_obj(payload) model_uid = body.model - kwargs = { - key: value - for key, value in payload.items() - if key not in RerankRequest.__annotations__.keys() - } try: model 
= await (await self._get_supervisor_ref()).get_model(model_uid) @@ -1333,6 +1329,10 @@ async def rerank(self, request: Request) -> Response: raise HTTPException(status_code=500, detail=str(e)) try: + if body.kwargs is not None: + parsed_kwargs = json.loads(body.kwargs) + else: + parsed_kwargs = {} scores = await model.rerank( body.documents, body.query, @@ -1340,7 +1340,7 @@ async def rerank(self, request: Request) -> Response: max_chunks_per_doc=body.max_chunks_per_doc, return_documents=body.return_documents, return_len=body.return_len, - **kwargs, + **parsed_kwargs, ) return Response(scores, media_type="application/json") except RuntimeError as re: diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index ed71a7bf05..ab03c566c1 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -174,6 +174,7 @@ def rerank( "max_chunks_per_doc": max_chunks_per_doc, "return_documents": return_documents, "return_len": return_len, + "kwargs": json.dumps(kwargs), } request_body.update(kwargs) response = requests.post(url, json=request_body, headers=self.auth_headers) diff --git a/xinference/model/rerank/core.py b/xinference/model/rerank/core.py index ffd1485eb5..4eb071ecd1 100644 --- a/xinference/model/rerank/core.py +++ b/xinference/model/rerank/core.py @@ -179,6 +179,7 @@ def _auto_detect_type(model_path): return rerank_type def load(self): + logger.info("Loading rerank model: %s", self._model_path) flash_attn_installed = importlib.util.find_spec("flash_attn") is not None if ( self._auto_detect_type(self._model_path) != "normal" @@ -189,6 +190,7 @@ def load(self): "will force set `use_fp16` to True" ) self._use_fp16 = True + if self._model_spec.type == "normal": try: import sentence_transformers @@ -250,22 +252,27 @@ def rerank( **kwargs, ) -> Rerank: assert self._model is not None - if kwargs: - raise ValueError("rerank hasn't support extra parameter.") if max_chunks_per_doc is not None: raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.") + logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model) sentence_combinations = [[query, doc] for doc in documents] # reset n tokens self._model.model.n_tokens = 0 if self._model_spec.type == "normal": similarity_scores = self._model.predict( - sentence_combinations, convert_to_numpy=False, convert_to_tensor=True + sentence_combinations, + convert_to_numpy=False, + convert_to_tensor=True, + **kwargs, ).cpu() if similarity_scores.dtype == torch.bfloat16: similarity_scores = similarity_scores.float() else: # Related issue: https://github.com/xorbitsai/inference/issues/1775 - similarity_scores = self._model.compute_score(sentence_combinations) + similarity_scores = self._model.compute_score( + sentence_combinations, **kwargs + ) + if not isinstance(similarity_scores, Sequence): similarity_scores = [similarity_scores] elif ( diff --git a/xinference/model/rerank/tests/test_rerank.py b/xinference/model/rerank/tests/test_rerank.py index 4ceac1c811..b76572dbad 100644 --- a/xinference/model/rerank/tests/test_rerank.py +++ b/xinference/model/rerank/tests/test_rerank.py @@ -118,9 +118,8 @@ def test_restful_api(model_name, setup): kwargs = { "invalid": "invalid", } - with pytest.raises(RuntimeError) as err: - scores = model.rerank(corpus, query, **kwargs) - assert "hasn't support" in str(err.value) + with pytest.raises(RuntimeError): + model.rerank(corpus, query, **kwargs) def test_from_local_uri(): From 
4c96475b8f90e354aa1b47856fda4db098b62b65 Mon Sep 17 00:00:00 2001 From: codingl2k1 <138426806+codingl2k1@users.noreply.github.com> Date: Fri, 15 Nov 2024 10:33:11 +0100 Subject: [PATCH 7/7] ENH: Update fish audio (#2555) Co-authored-by: qinxuye --- .github/workflows/python.yaml | 3 + setup.cfg | 4 + xinference/deploy/docker/requirements.txt | 4 +- xinference/deploy/docker/requirements_cpu.txt | 4 +- xinference/model/audio/model_spec.json | 2 +- .../fish_speech/configs/__init__.py | 0 .../fish_speech/configs/lora/__init__.py | 0 .../fish_speech/fish_speech/conversation.py | 254 +++++++ .../fish_speech/datasets/__init__.py | 0 .../fish_speech/datasets/protos/__init__.py | 0 .../fish_speech/i18n/locale/__init__.py | 0 .../fish_speech/i18n/locale/en_US.json | 3 +- .../fish_speech/i18n/locale/es_ES.json | 3 +- .../fish_speech/i18n/locale/ja_JP.json | 4 +- .../fish_speech/i18n/locale/ko_KR.json | 123 ++++ .../fish_speech/i18n/locale/zh_CN.json | 3 +- .../fish_speech/models/__init__.py | 0 .../fish_speech/models/text2semantic/llama.py | 87 ++- .../models/vqgan/modules/__init__.py | 0 .../models/vqgan/modules/firefly.py | 18 +- .../fish_speech/models/vqgan/modules/fsq.py | 2 +- .../fish_speech/fish_speech/text/clean.py | 33 +- .../fish_speech/fish_speech/utils/__init__.py | 3 +- .../fish_speech/fish_speech/utils/utils.py | 22 + .../fish_speech/fish_speech/webui/__init__.py | 0 .../fish_speech/webui/launch_utils.py | 2 +- .../fish_speech/fish_speech/webui/manage.py | 2 +- .../thirdparty/fish_speech/tools/api.py | 653 ++++++++++++++++-- .../thirdparty/fish_speech/tools/commons.py | 35 - .../thirdparty/fish_speech/tools/e2e_webui.py | 232 +++++++ .../thirdparty/fish_speech/tools/fish_e2e.py | 298 ++++++++ .../fish_speech/tools/llama/__init__.py | 0 .../fish_speech/tools/llama/generate.py | 402 ++++++++++- .../fish_speech/tools/msgpack_api.py | 119 +++- .../thirdparty/fish_speech/tools/post_api.py | 52 +- .../thirdparty/fish_speech/tools/schema.py | 187 +++++ .../fish_speech/tools/vqgan/__init__.py | 0 .../fish_speech/tools/vqgan/extract_vq.py | 8 +- .../fish_speech/tools/vqgan/inference.py | 5 +- .../thirdparty/fish_speech/tools/webui.py | 213 ++++-- 40 files changed, 2505 insertions(+), 275 deletions(-) delete mode 100644 xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py delete mode 100644 xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py delete mode 100644 xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py delete mode 100644 xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py delete mode 100644 xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py create mode 100644 xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json delete mode 100644 xinference/thirdparty/fish_speech/fish_speech/models/__init__.py delete mode 100644 xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py delete mode 100644 xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py delete mode 100644 xinference/thirdparty/fish_speech/tools/commons.py create mode 100644 xinference/thirdparty/fish_speech/tools/e2e_webui.py create mode 100644 xinference/thirdparty/fish_speech/tools/fish_e2e.py delete mode 100644 xinference/thirdparty/fish_speech/tools/llama/__init__.py create mode 100644 xinference/thirdparty/fish_speech/tools/schema.py delete mode 100644 xinference/thirdparty/fish_speech/tools/vqgan/__init__.py diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 
915811c7b0..a2a42796f2 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -180,6 +180,9 @@ jobs: ${{ env.SELF_HOST_PYTHON }} -m pip uninstall -y "faster_whisper" ${{ env.SELF_HOST_PYTHON }} -m pip install -U accelerate ${{ env.SELF_HOST_PYTHON }} -m pip install -U verovio + ${{ env.SELF_HOST_PYTHON }} -m pip install -U cachetools + ${{ env.SELF_HOST_PYTHON }} -m pip install -U silero-vad + ${{ env.SELF_HOST_PYTHON }} -m pip install -U pydantic ${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \ --disable-warnings \ --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/core/tests/test_continuous_batching.py && \ diff --git a/setup.cfg b/setup.cfg index 2d420de8fd..0bbd2f5021 100644 --- a/setup.cfg +++ b/setup.cfg @@ -129,6 +129,8 @@ all = natsort # For Fish Speech loralib # For Fish Speech ormsgpack # For Fish Speech + cachetools # For Fish Speech + silero-vad # For Fish Speech qwen-vl-utils # For qwen2-vl datamodel_code_generator # for minicpm-4B jsonschema # for minicpm-4B @@ -210,6 +212,8 @@ audio = natsort # For Fish Speech loralib # For Fish Speech ormsgpack # For Fish Speech + cachetools # For Fish Speech + silero-vad # For Fish Speech doc = ipython>=6.5.0 sphinx>=3.0.0 diff --git a/xinference/deploy/docker/requirements.txt b/xinference/deploy/docker/requirements.txt index 6671d4bfb4..79d4e2defd 100644 --- a/xinference/deploy/docker/requirements.txt +++ b/xinference/deploy/docker/requirements.txt @@ -7,7 +7,7 @@ click tqdm>=4.27 tabulate requests -pydantic +pydantic>2 fastapi>=0.110.3 uvicorn huggingface-hub>=0.19.4 @@ -72,6 +72,8 @@ loguru # For Fish Speech natsort # For Fish Speech loralib # For Fish Speech ormsgpack # For Fish Speech +cachetools # For Fish Speech +silero-vad # For Fish Speech qwen-vl-utils # For qwen2-vl datamodel_code_generator # for minicpm-4B jsonschema # for minicpm-4B diff --git a/xinference/deploy/docker/requirements_cpu.txt b/xinference/deploy/docker/requirements_cpu.txt index e4410a0358..4105a2e709 100644 --- a/xinference/deploy/docker/requirements_cpu.txt +++ b/xinference/deploy/docker/requirements_cpu.txt @@ -6,7 +6,7 @@ click tqdm>=4.27 tabulate requests -pydantic +pydantic>2 fastapi>=0.110.3 uvicorn huggingface-hub>=0.19.4 @@ -67,6 +67,8 @@ loguru # For Fish Speech natsort # For Fish Speech loralib # For Fish Speech ormsgpack # For Fish Speech +cachetools # For Fish Speech +silero-vad # For Fish Speech qwen-vl-utils # For qwen2-vl datamodel_code_generator # for minicpm-4B jsonschema # for minicpm-4B diff --git a/xinference/model/audio/model_spec.json b/xinference/model/audio/model_spec.json index b7c436d474..912d8399e9 100644 --- a/xinference/model/audio/model_spec.json +++ b/xinference/model/audio/model_spec.json @@ -159,7 +159,7 @@ "model_name": "FishSpeech-1.4", "model_family": "FishAudio", "model_id": "fishaudio/fish-speech-1.4", - "model_revision": "3c49651b8e583b6b13f55e375432e0d57e1aa84d", + "model_revision": "069c573759936b35191d3380deb89183c0656f59", "model_ability": "text-to-audio", "multilingual": true } diff --git a/xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/conversation.py 
b/xinference/thirdparty/fish_speech/fish_speech/conversation.py index c9ca0ef918..9bbc1cdb6c 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/conversation.py +++ b/xinference/thirdparty/fish_speech/fish_speech/conversation.py @@ -1,2 +1,256 @@ +from dataclasses import dataclass, field +from typing import Literal + +import torch +from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast + +IM_START_TOKEN = "<|im_start|>" +IM_END_TOKEN = "<|im_end|>" SEMANTIC_TOKEN = "<|semantic|>" +MEL_TOKEN = "<|mel|>" +PHONEME_START_TOKEN = "<|phoneme_start|>" +PHONEME_END_TOKEN = "<|phoneme_end|>" +ALL_SPECIAL_TOKENS = [ + IM_START_TOKEN, + IM_END_TOKEN, + SEMANTIC_TOKEN, + MEL_TOKEN, + PHONEME_START_TOKEN, + PHONEME_END_TOKEN, +] + CODEBOOK_PAD_TOKEN_ID = 0 + + +class FishTokenizerConfig(PretrainedConfig): + share_codebook_embeddings: bool = True + codebook_size: int = 1024 + num_codebooks: int = 8 + + +class FishTokenizerFast(PreTrainedTokenizerFast): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True) + self.codebook_size = kwargs.pop("codebook_size", 1024) + self.num_codebooks = kwargs.pop("num_codebooks", 8) + + +AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast) + + +@dataclass(kw_only=True) +class BasePart: + pass + + +@dataclass(kw_only=True) +class VQPart(BasePart): + codes: torch.Tensor + + +@dataclass(kw_only=True) +class TextPart(BasePart): + text: str + + +@dataclass(kw_only=True) +class MelPart(BasePart): + mels: torch.Tensor + + +@dataclass(kw_only=True) +class EncodedMessage: + tokens: torch.Tensor + labels: torch.Tensor + vq_parts: list[torch.Tensor] + mel_parts: list[torch.Tensor] + vq_require_losses: torch.Tensor | None = None + + +@dataclass(kw_only=True) +class Message: + role: Literal["system", "user", "assistant"] + parts: list[VQPart | TextPart | MelPart] = field(default_factory=list) + add_im_start: bool = True + add_im_end: bool = True + cal_loss: bool = False + + # By default, ignore the loss of the auto-generated im_start token + ignore_im_start_loss: bool = True + + def encode( + self: "Message", + tokenizer: AutoTokenizer, + ) -> EncodedMessage: + all_tokens = [] + all_labels = [] + + # Multi-modal tokens + vq_parts = [] + mel_parts = [] + + semantic_id, mel_id = tokenizer.convert_tokens_to_ids( + [SEMANTIC_TOKEN, MEL_TOKEN] + ) + + parts = self.parts.copy() + if self.add_im_start: + parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n")) + + if self.add_im_end: + parts.append(TextPart(text="<|im_end|>")) + + for part in parts: + if isinstance(part, TextPart): + tokens = tokenizer.encode( + part.text, + add_special_tokens=False, + truncation=False, + return_tensors="pt", + ).int()[0] + elif isinstance(part, VQPart): + tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id + codes = part.codes.clone() + 1 + + if getattr(tokenizer, "share_codebook_embeddings", True) is False: + for i in range(len(codes)): + codes[i] += tokenizer.codebook_size * i + + vq_parts.append(codes) + elif isinstance(part, MelPart): + tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id + mel_parts.append(part.mels) + else: + raise ValueError(f"Unsupported part type: {type(part)}") + + all_tokens.append(tokens) + if self.cal_loss: + all_labels.append(tokens.clone()) + else: + all_labels.append(torch.full_like(tokens, -100)) + + tokens = torch.cat(all_tokens, dim=0) + labels = torch.cat(all_labels, 
dim=0) + assert tokens.shape == labels.shape + + if self.ignore_im_start_loss and self.add_im_start: + labels[: len(all_tokens[0])] = -100 + + return EncodedMessage( + tokens=tokens, + labels=labels, + vq_parts=vq_parts, + mel_parts=mel_parts, + ) + + +@dataclass +class Conversation: + messages: list[Message] + + def encode( + self: "Conversation", + tokenizer: AutoTokenizer, + add_shift: bool = True, + ) -> EncodedMessage: + # Build the input_ids and labels + tokens = [] + labels = [] + vq_parts = [] + mel_parts = [] + vq_require_losses = [] + + for message in self.messages: + encoded = message.encode( + tokenizer, + ) + tokens.append(encoded.tokens) + labels.append(encoded.labels) + vq_parts.extend(encoded.vq_parts) + mel_parts.extend(encoded.mel_parts) + vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts)) + + tokens = torch.cat(tokens, dim=0) + labels = torch.cat(labels, dim=0) + vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool) + + if add_shift: + tokens = tokens[:-1] + labels = labels[1:] + + assert tokens.dtype in [ + torch.int, + torch.long, + ], f"Invalid dtype: {tokens.dtype}, conv: {conversation}" + + return EncodedMessage( + tokens=tokens, + labels=labels, + vq_parts=vq_parts, + mel_parts=mel_parts, + vq_require_losses=vq_require_losses, + ) + + def encode_for_inference( + self: "Conversation", + tokenizer: AutoTokenizer, + num_codebooks: int, + ) -> EncodedMessage: + encoded = self.encode(tokenizer, add_shift=False) + tokens = encoded.tokens + values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int) + values[0] = tokens + + if encoded.vq_parts is None or len(encoded.vq_parts) == 0: + return values + + semantic_id, mel_id = tokenizer.convert_tokens_to_ids( + [SEMANTIC_TOKEN, MEL_TOKEN] + ) + vq_parts = encoded.vq_parts + vq_parts = torch.cat(vq_parts, dim=1) + values[1:, tokens == semantic_id] = vq_parts + return values + + def visualize(self: "Conversation", tokenizer: AutoTokenizer): + encoded = self.encode(tokenizer, add_shift=False) + + print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="") + print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="") + + for tok, lab in zip(encoded.tokens, encoded.labels): + val = tokenizer.decode(tok, skip_special_tokens=False) + if val == "\n": + val = "\\n\n" + + if lab == -100: + print_in_green(val) + else: + print_in_blue(val) + + print() + + +if __name__ == "__main__": + message0 = Message( + role="user", + parts=[ + TextPart(text="Hello, how are you?"), + VQPart(codes=torch.zeros((4, 10))), + ], + cal_loss=False, + ) + + message1 = Message( + role="assistant", + parts=[TextPart(text="I'm fine, thank you.")], + cal_loss=True, + ) + conversation = Conversation([message0, message1]) + tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct") + conversation.visualize(tokenizer) + + encoded = conversation.encode(tokenizer) + print(encoded) + print(tokenizer.batch_decode(encoded.tokens)) diff --git a/xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py 
b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json index 6e280c236e..d36c774313 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +++ b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json @@ -118,5 +118,6 @@ "new": "new", "Realtime Transform Text": "Realtime Transform Text", "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)", - "Text Normalization": "Text Normalization" + "Text Normalization": "Text Normalization", + "Select Example Audio": "Select Example Audio" } diff --git a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json index 3285341f68..7a4757967d 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +++ b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json @@ -118,5 +118,6 @@ "new": "nuevo", "Realtime Transform Text": "Transformación de Texto en Tiempo Real", "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)", - "Text Normalization": "Normalización de Texto" + "Text Normalization": "Normalización de Texto", + "Select Example Audio": "Selecionar áudio de exemplo" } diff --git a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json index d30bac7bcd..863b8b0b41 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +++ b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json @@ -118,6 +118,6 @@ "new": "新規", "Realtime Transform Text": "リアルタイム変換テキスト", "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)", - "Text Normalization": "テキスト正規化" - + "Text Normalization": "テキスト正規化", + "Select Example Audio": "サンプル音声を選択" } diff --git a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json new file mode 100644 index 0000000000..180263874b --- /dev/null +++ b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json @@ -0,0 +1,123 @@ +{ + "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.", + "Accumulate Gradient Batches": "그라디언트 배치 누적", + "Add to Processing Area": "처리 영역에 추가", + "Added path successfully!": "경로가 성공적으로 추가되었습니다!", + "Advanced Config": "고급 설정", + "Base LLAMA Model": "기본 LLAMA 모델", + "Batch Inference": "배치 추론", + "Batch Size": "배치 크기", + "Changing with the Model Path": "모델 경로에 따라 변경 중", + "Chinese": "중국어", + "Compile Model": "모델 컴파일", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.", + "Copy": "복사", + "Data Preprocessing": "데이터 전처리", + "Data Preprocessing Path": "데이터 전처리 경로", + "Data Source": "데이터 소스", + "Decoder Model Config": "디코더 모델 설정", + "Decoder Model Path": "디코더 모델 경로", + "Disabled": 
"비활성화 됨", + "Enable Reference Audio": "참고 음성 활성화", + "English": "영어", + "Error Message": "오류 메시지", + "File Preprocessing": "파일 전처리", + "Generate": "생성", + "Generated Audio": "생성된 오디오", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.", + "Infer interface is closed": "추론 인터페이스가 닫혔습니다.", + "Inference Configuration": "추론 설정", + "Inference Server Configuration": "추론 서버 설정", + "Inference Server Error": "추론 서버 오류", + "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.", + "Initial Learning Rate": "초기 학습률", + "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로", + "Input Text": "입력 텍스트", + "Invalid path: {}": "유효하지 않은 경로: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.", + "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)", + "Japanese": "일본어", + "LLAMA Configuration": "LLAMA 설정", + "LLAMA Model Config": "LLAMA 모델 설정", + "LLAMA Model Path": "LLAMA 모델 경로", + "Labeling Device": "라벨링 장치", + "LoRA Model to be merged": "병합할 LoRA 모델", + "Maximum Audio Duration": "최대 오디오 길이", + "Maximum Length per Sample": "샘플당 최대 길이", + "Maximum Training Steps": "최대 학습 단계", + "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)", + "Merge": "병합", + "Merge LoRA": "LoRA 병합", + "Merge successfully": "성공적으로 병합 되었습니다.", + "Minimum Audio Duration": "최소 오디오 길이", + "Model Output Path": "모델 출력 경로", + "Model Size": "모델 크기", + "Move": "이동", + "Move files successfully": "파일이 성공적으로 이동되었습니다.", + "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.", + "No selected options": "옵션이 선택되지 않았습니다.", + "Number of Workers": "작업자 수", + "Open Inference Server": "추론 서버 열기", + "Open Labeler WebUI": "라벨러 WebUI 열기", + "Open Tensorboard": "Tensorboard 열기", + "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.", + "Optional Label Language": "선택적 라벨 언어", + "Optional online ver": "온라인 버전 선택", + "Output Path": "출력 경로", + "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.", + "Precision": "정밀도", + "Probability of applying Speaker Condition": "화자 조건 적용 확률", + "Put your text here.": "여기에 텍스트를 입력하세요.", + "Reference Audio": "참고 오디오", + "Reference Text": "참고 텍스트", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.", + "Remove Selected Data": "선택한 데이터 제거", + "Removed path successfully!": "경로가 성공적으로 제거되었습니다!", + "Repetition Penalty": "반복 패널티", + "Save model every n steps": "n 단계마다 모델 저장", + "Select LLAMA ckpt": "LLAMA ckpt 선택", + "Select VITS ckpt": "VITS ckpt 선택", + "Select VQGAN ckpt": "VQGAN ckpt 선택", + "Select source file processing method": "소스 파일 처리 방법 선택", + "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)", + "Selected: {}": "선택됨: {}", + "Speaker": "화자", + "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다", + "Start Training": "학습 시작", + "Streaming Audio": "스트리밍 오디오", + "Streaming Generate": "스트리밍 생성", + "Tensorboard Host": "Tensorboard 호스트", + "Tensorboard Log Path": "Tensorboard 로그 경로", + "Tensorboard Port": "Tensorboard 포트", + "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다", + "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.", + "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. 
{}자 이하로 입력해주세요.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.", + "Training Configuration": "학습 설정", + "Training Error": "학습 오류", + "Training stopped": "학습이 중지되었습니다.", + "Type name of the speaker": "화자의 이름을 입력하세요.", + "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.", + "Use LoRA": "LoRA 사용", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.", + "Use filelist": "파일 목록 사용", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.", + "VITS Configuration": "VITS 설정", + "VQGAN Configuration": "VQGAN 설정", + "Validation Batch Size": "검증 배치 크기", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.", + "WebUI Host": "WebUI 호스트", + "WebUI Port": "WebUI 포트", + "Whisper Model": "Whisper 모델", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다", + "latest": "최신", + "new": "새로운", + "Realtime Transform Text": "실시간 텍스트 변환", + "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)", + "Text Normalization": "텍스트 정규화", + "Select Example Audio": "예시 오디오 선택" +} diff --git a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json index 3dd1a5cd1c..9068ef0b9a 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +++ b/xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json @@ -118,5 +118,6 @@ "new": "创建新的检查点", "Realtime Transform Text": "实时规范化文本", "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览", - "Text Normalization": "文本规范化" + "Text Normalization": "文本规范化", + "Select Example Audio": "选择参考音频" } diff --git a/xinference/thirdparty/fish_speech/fish_speech/models/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/models/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py b/xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py index 0725dfb9b7..6ea15e595f 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +++ b/xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py @@ -1,3 +1,4 @@ +import dataclasses import json import math from collections import OrderedDict @@ -57,6 +58,10 @@ class BaseModelArgs: # Initialize the model initializer_range: float = 0.02 + # Dummy vars + is_reward_model: bool = False + share_codebook_embeddings: bool = True + def __post_init__(self): if self.n_local_heads == -1: self.n_local_heads = self.n_head @@ -100,6 +105,28 @@ class NaiveModelArgs(BaseModelArgs): class 
DualARModelArgs(BaseModelArgs): model_type: str = "dual_ar" n_fast_layer: int = 4 + fast_dim: int | None = None + fast_n_head: int | None = None + fast_n_local_heads: int | None = None + fast_head_dim: int | None = None + fast_intermediate_size: int | None = None + fast_attention_qkv_bias: bool | None = None + + def __post_init__(self): + super().__post_init__() + + self.fast_dim = self.fast_dim or self.dim + self.fast_n_head = self.fast_n_head or self.n_head + self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads + self.fast_head_dim = self.fast_head_dim or self.head_dim + self.fast_intermediate_size = ( + self.fast_intermediate_size or self.intermediate_size + ) + self.fast_attention_qkv_bias = ( + self.fast_attention_qkv_bias + if self.fast_attention_qkv_bias is not None + else self.attention_qkv_bias + ) class KVCache(nn.Module): @@ -369,7 +396,10 @@ def from_pretrained( model = simple_quantizer.convert_for_runtime() weights = torch.load( - Path(path) / "model.pth", map_location="cpu", mmap=True + Path(path) / "model.pth", + map_location="cpu", + mmap=True, + weights_only=True, ) if "state_dict" in weights: @@ -471,20 +501,46 @@ class DualARTransformer(BaseTransformer): def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: super().__init__(config, init_weights=False, tokenizer=tokenizer) + # Project to fast dim if needed + if config.fast_dim is not None and config.fast_dim != config.dim: + self.fast_project_in = nn.Linear(config.dim, config.fast_dim) + else: + self.fast_project_in = nn.Identity() + # Fast transformer - self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim) + self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim) # The equivalent bs is so large that sdpa doesn't work + override_config = dataclasses.replace( + config, + dim=config.fast_dim, + n_head=config.fast_n_head, + n_local_heads=config.fast_n_local_heads, + head_dim=config.fast_head_dim, + intermediate_size=config.fast_intermediate_size, + attention_qkv_bias=config.fast_attention_qkv_bias, + ) + self.fast_layers = nn.ModuleList( - TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer) + TransformerBlock(override_config, use_sdpa=False) + for _ in range(config.n_fast_layer) ) - self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps) self.fast_output = nn.Linear( - config.dim, + config.fast_dim, config.codebook_size, bias=False, ) + self.register_buffer( + "fast_freqs_cis", + precompute_freqs_cis( + config.num_codebooks, + config.fast_dim // config.fast_n_head, + config.rope_base, + ), + persistent=False, + ) self.apply(self._init_weights) def setup_caches( @@ -492,7 +548,7 @@ def setup_caches( ): super().setup_caches(max_batch_size, max_seq_len, dtype) - head_dim = self.config.dim // self.config.n_head + head_dim = self.config.fast_dim // self.config.fast_n_head # Fast transformer # The max seq len here is the number of codebooks @@ -500,7 +556,7 @@ def setup_caches( b.attention.kv_cache = KVCache( max_batch_size, self.config.num_codebooks, - self.config.n_local_heads, + self.config.fast_n_local_heads, head_dim, dtype=dtype, ) @@ -513,13 +569,13 @@ def forward( parent_result = super().forward(inp, key_padding_mask) token_logits = parent_result.logits x = parent_result.hidden_states + x = self.fast_project_in(x) # Fast transformer fast_seq_len = self.config.num_codebooks fast_mask = self.causal_mask[ None, None, :fast_seq_len, :fast_seq_len ] # (B, N, Q, 
K) - fast_freqs_cis = self.freqs_cis[:fast_seq_len] # Drop the last token and rotate left codebooks = inp[:, 1:-1, 1:] @@ -542,9 +598,11 @@ def forward( for layer in self.fast_layers: if self.config.use_gradient_checkpointing and self.training: - x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True) + x = checkpoint( + layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True + ) else: - x = layer(x, fast_freqs_cis, fast_mask) + x = layer(x, self.fast_freqs_cis, fast_mask) # unflatten the batch and num_codebooks fast_out = self.fast_norm(x) @@ -584,7 +642,7 @@ def forward_generate_fast( fast_mask = self.causal_mask[ None, None, input_pos, : self.config.num_codebooks ] # (B, N, Q, K) - fast_freqs_cis = self.freqs_cis[input_pos] + fast_freqs_cis = self.fast_freqs_cis[input_pos] for layer in self.fast_layers: x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos) @@ -595,6 +653,13 @@ def forward_generate_fast( return codebook_logits + def forward_generate( + self, x: Tensor, input_pos: Optional[Tensor] = None + ) -> TransformerForwardResult: + x = super().forward_generate(x, input_pos) + x.hidden_states = self.fast_project_in(x.hidden_states) + return x + class TransformerBlock(nn.Module): def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None: diff --git a/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py b/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py index aa21839b54..91fc9118cc 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +++ b/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py @@ -102,8 +102,8 @@ def weight_norm(self, name="weight", dim=0): self.conv = weight_norm(self.conv, name=name, dim=dim) return self - def remove_weight_norm(self): - self.conv = remove_parametrizations(self.conv) + def remove_parametrizations(self, name="weight"): + self.conv = remove_parametrizations(self.conv, name) return self @@ -128,8 +128,8 @@ def weight_norm(self, name="weight", dim=0): self.conv = weight_norm(self.conv, name=name, dim=dim) return self - def remove_weight_norm(self): - self.conv = remove_parametrizations(self.conv) + def remove_parametrizations(self, name="weight"): + self.conv = remove_parametrizations(self.conv, name) return self @@ -178,9 +178,9 @@ def forward(self, x): def remove_parametrizations(self): for conv in self.convs1: - remove_parametrizations(conv, tensor_name="weight") + conv.remove_parametrizations() for conv in self.convs2: - remove_parametrizations(conv, tensor_name="weight") + conv.remove_parametrizations() class ParallelBlock(nn.Module): @@ -288,11 +288,11 @@ def forward(self, x): def remove_parametrizations(self): for up in self.ups: - remove_parametrizations(up, tensor_name="weight") + up.remove_parametrizations() for block in self.resblocks: block.remove_parametrizations() - remove_parametrizations(self.conv_pre, tensor_name="weight") - remove_parametrizations(self.conv_post, tensor_name="weight") + self.conv_pre.remove_parametrizations() + self.conv_post.remove_parametrizations() # DropPath copied from timm library diff --git a/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py 
b/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py index 7ea4853376..954553bbfe 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +++ b/xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py @@ -99,7 +99,7 @@ def forward(self, z) -> FSQResult: if diff > 0: result.z = F.pad(result.z, (left, right)) elif diff < 0: - result.z = result.z[..., left:-right] + result.z = result.z[..., -left:right] return result diff --git a/xinference/thirdparty/fish_speech/fish_speech/text/clean.py b/xinference/thirdparty/fish_speech/fish_speech/text/clean.py index c228dfcd13..dbaf843d78 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/text/clean.py +++ b/xinference/thirdparty/fish_speech/fish_speech/text/clean.py @@ -1,6 +1,8 @@ import re SYMBOLS_MAPPING = { + "\n": "", + "…": ".", "“": "'", "”": "'", "‘": "'", @@ -13,7 +15,19 @@ ")": "", "(": "", ")": "", - "・": "·", + "・": "", + "·": "", + "「": "'", + "」": "'", + "《": "'", + "》": "'", + "—": "", + "~": "", + "~": "", + ":": ",", + ";": ",", + ";": ",", + ":": ",", } REPLACE_SYMBOL_REGEX = re.compile( @@ -21,6 +35,17 @@ ) +EMOJI_REGEX = re.compile( + "[" + "\U0001F600-\U0001F64F" # emoticons + "\U0001F300-\U0001F5FF" # symbols & pictographs + "\U0001F680-\U0001F6FF" # transport & map symbols + "\U0001F1E0-\U0001F1FF" # flags (iOS) + "]+", + flags=re.UNICODE, +) + + def clean_text(text): # Clean the text text = text.strip() @@ -28,4 +53,10 @@ def clean_text(text): # Replace all chinese symbols with their english counterparts text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) + # Remove emojis + text = EMOJI_REGEX.sub(r"", text) + + # Remove continuous periods (...) and commas (,,,) + text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text) + return text diff --git a/xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py index 05378519db..53cf2f2317 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +++ b/xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py @@ -5,7 +5,7 @@ from .logger import RankedLogger from .logging_utils import log_hyperparameters from .rich_utils import enforce_tags, print_config_tree -from .utils import extras, get_metric_value, task_wrapper +from .utils import extras, get_metric_value, set_seed, task_wrapper __all__ = [ "enforce_tags", @@ -20,4 +20,5 @@ "braceexpand", "get_latest_checkpoint", "autocast_exclude_mps", + "set_seed", ] diff --git a/xinference/thirdparty/fish_speech/fish_speech/utils/utils.py b/xinference/thirdparty/fish_speech/fish_speech/utils/utils.py index c546bfa1ed..5a34bdcfed 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +++ b/xinference/thirdparty/fish_speech/fish_speech/utils/utils.py @@ -1,7 +1,10 @@ +import random import warnings from importlib.util import find_spec from typing import Callable +import numpy as np +import torch from omegaconf import DictConfig from .logger import RankedLogger @@ -112,3 +115,22 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> float: log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") return metric_value + + +def set_seed(seed: int): + if seed < 0: + seed = -seed + if seed > (1 << 31): + seed = 1 << 31 + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if torch.backends.cudnn.is_available(): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py b/xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py b/xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py index 2f57b595a2..790c0e632c 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +++ b/xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py @@ -114,7 +114,7 @@ def __init__( block_title_text_weight="600", block_border_width="3px", block_shadow="*shadow_drop_lg", - button_shadow="*shadow_drop_lg", + # button_shadow="*shadow_drop_lg", button_small_padding="0px", button_large_padding="3px", ) diff --git a/xinference/thirdparty/fish_speech/fish_speech/webui/manage.py b/xinference/thirdparty/fish_speech/fish_speech/webui/manage.py index 4ec3fcac25..c21233eee3 100644 --- a/xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +++ b/xinference/thirdparty/fish_speech/fish_speech/webui/manage.py @@ -794,7 +794,7 @@ def llama_quantify(llama_weight, quantify_mode): value="VQGAN", ) with gr.Row(): - with gr.Tabs(): + with gr.Column(): with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page: gr.HTML("You don't need to train this model!") diff --git a/xinference/thirdparty/fish_speech/tools/api.py b/xinference/thirdparty/fish_speech/tools/api.py index 7fcc9330ae..cc12f3a0fd 100644 --- a/xinference/thirdparty/fish_speech/tools/api.py +++ b/xinference/thirdparty/fish_speech/tools/api.py @@ -1,16 +1,16 @@ -import base64 import io -import json +import os import queue -import random -import sys +import re +import time import traceback import wave from argparse import ArgumentParser from http import HTTPStatus from pathlib import Path -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Any +import librosa import numpy as np import ormsgpack # import pyrootutils @@ -28,27 +28,74 @@ # Kui, # OpenAPI, # StreamResponse, +# request, # ) # from kui.asgi.routing import MultimethodRoutes from loguru import logger -from pydantic import BaseModel, Field, conint +from transformers import AutoTokenizer # pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +import struct +from threading import Lock + +import httpx +from cachetools import LRUCache, cached +from funasr import AutoModel +from silero_vad import get_speech_timestamps, load_silero_vad + +from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN +from fish_speech.models.text2semantic.llama import BaseModelArgs # from fish_speech.models.vqgan.lit_module import VQGAN from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture from fish_speech.text.chn_text_norm.text import Text as ChnNormedText -from fish_speech.utils import autocast_exclude_mps -from tools.commons import ServeReferenceAudio, ServeTTSRequest +from fish_speech.utils import autocast_exclude_mps, set_seed from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, 
read_ref_text from tools.llama.generate import ( GenerateRequest, GenerateResponse, WrappedGenerateResponse, launch_thread_safe_queue, + launch_thread_safe_queue_agent, +) +from tools.schema import ( + GLOBAL_NUM_SAMPLES, + ASRPackRequest, + ServeASRRequest, + ServeASRResponse, + ServeASRSegment, + ServeAudioPart, + ServeForwardMessage, + ServeMessage, + ServeRequest, + ServeResponse, + ServeStreamDelta, + ServeStreamResponse, + ServeTextPart, + ServeTimedASRResponse, + ServeTTSRequest, + ServeVQGANDecodeRequest, + ServeVQGANDecodeResponse, + ServeVQGANEncodeRequest, + ServeVQGANEncodeResponse, + ServeVQPart, ) from tools.vqgan.inference import load_model as load_decoder_model +global_lock = Lock() + +# Whether to disable keepalive (which is helpful if the server is in the same cluster) +DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true" +async_client = httpx.AsyncClient( + timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None) +) +backends = torchaudio.list_audio_backends() + +if "ffmpeg" in backends: + backend = "ffmpeg" +else: + backend = "soundfile" + def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): buffer = io.BytesIO() @@ -91,9 +138,7 @@ def load_audio(reference_audio, sr): audio_data = reference_audio reference_audio = io.BytesIO(audio_data) - waveform, original_sr = torchaudio.load( - reference_audio, backend="sox" if sys.platform == "linux" else "soundfile" - ) + waveform, original_sr = torchaudio.load(reference_audio, backend=backend) if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) @@ -167,9 +212,390 @@ def get_content_type(audio_format): return "application/octet-stream" +@torch.no_grad() +@torch.autocast(device_type="cuda", dtype=torch.half) +def batch_encode(model, audios: list[bytes | torch.Tensor]): + audios = [ + ( + torch.from_numpy( + librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0] + )[None] + if isinstance(audio, bytes) + else audio + ) + for audio in audios + ] + + # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios): + # raise ValueError("Single audio length is too long (>120s)") + + max_length = max(audio.shape[-1] for audio in audios) + print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s") + + lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device) + max_length = lengths.max().item() + padded = torch.stack( + [ + torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1])) + for audio in audios + ] + ).to(model.device) + + features, feature_lengths = model.encode(padded, audio_lengths=lengths) + features, feature_lengths = features.cpu(), feature_lengths.cpu() + + return [feature[..., :length] for feature, length in zip(features, feature_lengths)] + + +@cached( + cache=LRUCache(maxsize=10000), + key=lambda model, audios: (model.device, tuple(audios)), +) +def cached_vqgan_batch_encode(model, audios: list[bytes]): + return batch_encode(model, audios) + + +# @routes.http.post("/v1/vqgan/encode") +# def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]): +# +# start_time = time.time() +# tokens = cached_vqgan_batch_encode(decoder_model, payload.audios) +# logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms") +# +# return ormsgpack.packb( +# ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]), +# option=ormsgpack.OPT_SERIALIZE_PYDANTIC, +# ) + + +@torch.no_grad() 
+@torch.autocast(device_type="cuda", dtype=torch.half) +def vqgan_decode(model, features): + lengths = torch.tensor( + [feature.shape[-1] for feature in features], device=model.device + ) + max_length = lengths.max().item() + padded = torch.stack( + [ + torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1])) + for feature in features + ] + ).to(model.device) + + # If bs too large, we do micro batch decode + audios, audio_lengths = [], [] + for i in range(0, padded.shape[0], 8): + audio, audio_length = model.decode( + padded[i : i + 8], feature_lengths=lengths[i : i + 8] + ) + audios.append(audio) + audio_lengths.append(audio_length) + audios = torch.cat(audios, dim=0) + audio_lengths = torch.cat(audio_lengths, dim=0) + audios, audio_lengths = audios.cpu(), audio_lengths.cpu() + + return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)] + + +# @routes.http.post("/v1/vqgan/decode") +# def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]): +# tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens] +# start_time = time.time() +# audios = vqgan_decode(decoder_model, tokens) +# logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms") +# audios = [audio.astype(np.float16).tobytes() for audio in audios] +# return ormsgpack.packb( +# ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC +# ) + + +@torch.no_grad() +def batch_asr(model, audios, sr, language="auto"): + resampled_audios = [] + for audio in audios: + audio = torchaudio.functional.resample(audio, sr, 16000) + assert audio.ndim == 1 + resampled_audios.append(audio) + + with global_lock: + res = model.generate( + input=resampled_audios, + batch_size=len(resampled_audios), + language=language, + use_itn=True, + ) + + results = [] + for r, audio in zip(res, audios): + text = r["text"] + text = re.sub(r"<\|.*?\|>", "", text) + duration = len(audio) / sr * 1000 + huge_gap = False + + if "timestamp" in r and len(r["timestamp"]) > 2: + for timestamp_a, timestamp_b in zip( + r["timestamp"][:-1], r["timestamp"][1:] + ): + # If there is a gap of more than 5 seconds, we consider it as a huge gap + if timestamp_b[0] - timestamp_a[1] > 5000: + huge_gap = True + break + + # Doesn't make sense to have a huge gap at the end + if duration - r["timestamp"][-1][1] > 3000: + huge_gap = True + + results.append( + { + "text": text, + "duration": duration, + "huge_gap": huge_gap, + } + ) + + return results + + +# @routes.http.post("/v1/asr") +# def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]): +# start_time = time.time() +# audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios] +# audios = [torch.from_numpy(audio).float() for audio in audios] +# +# if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios): +# raise HTTPException(status_code=400, detail="Audio length is too long") +# +# transcriptions = batch_asr( +# asr_model, audios=audios, sr=payload.sample_rate, language=payload.language +# ) +# logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms") +# +# return ormsgpack.packb( +# ServeASRResponse(transcriptions=transcriptions), +# option=ormsgpack.OPT_SERIALIZE_PYDANTIC, +# ) + + +from fish_speech.conversation import Conversation, Message + + +def execute_request( + input_queue: queue.Queue, + tokenizer: AutoTokenizer, + config: BaseModelArgs, + request: ServeRequest, + device: str = "cuda:0", +): + semantic_id, im_end_id 
= tokenizer.convert_tokens_to_ids( + [SEMANTIC_TOKEN, IM_END_TOKEN] + ) + messages = [] + for message in request.messages: + messages.append(message.to_conversation_message()) + + assert len(messages) >= 1, "At least one message is required" + # assert messages[-1].role == "user", "The last message must be from the user" + + if messages[-1].role == "user": + messages.append(Message(role="assistant", parts=[], add_im_end=False)) + else: + assert ( + messages[-1].role == "assistant" + ), "The last message must be from the assistant" + messages[-1].add_im_end = False + + conv = Conversation(messages=messages) + prompt = conv.encode_for_inference( + tokenizer=tokenizer, num_codebooks=config.num_codebooks + ).to(device) + + if request.streaming: + for i in range(request.num_samples): + yield ServeStreamResponse( + sample_id=i, + delta=ServeStreamDelta( + role="assistant", + ), + ) + + req = { + "prompt": prompt, + "max_new_tokens": request.max_new_tokens, + "im_end_id": im_end_id, + "semantic_id": semantic_id, + "temperature": request.temperature, + "top_p": request.top_p, + "repetition_penalty": request.repetition_penalty, + "num_samples": request.num_samples, + "early_stop_threshold": request.early_stop_threshold, + } + + start = time.time() + response_queue = queue.Queue() + input_queue.put(GenerateRequest(req, response_queue)) + + # Decoding + decode_buffer = [[] for _ in range(request.num_samples)] + parts = [[] for _ in range(request.num_samples)] + + def send_reset_buffer(sample_id): + nonlocal decode_buffer + if len(decode_buffer[sample_id]) == 0: + return + + decoded = tokenizer.decode(decode_buffer[sample_id]) + part = ServeTextPart(text=decoded) + + if request.streaming: + yield ServeStreamResponse(delta=ServeStreamDelta(part=part)) + else: + parts[sample_id].append(part) + + decode_buffer[sample_id] = [] + + # Decode process + finished = [False for _ in range(request.num_samples)] + stats = {} + idx = 0 + while True: + response = response_queue.get() + + if response in ["stop", "error"]: + break + + for sample_id, tokens in enumerate(response): + if finished[sample_id]: + continue + + if tokens[0] == im_end_id: + finished[sample_id] = True + if request.streaming: + yield from send_reset_buffer(sample_id) + yield ServeStreamResponse( + sample_id=sample_id, + finish_reason="stop", + stats=stats, + ) + continue + + if tokens[0] == semantic_id and request.streaming: + yield from send_reset_buffer(sample_id) + # Streaming vq + _tokens = tokens[1:].clone() - 1 + + if config.share_codebook_embeddings is False: + for i in range(len(_tokens)): + _tokens[i] -= config.codebook_size * i + + yield ServeStreamResponse( + sample_id=sample_id, + delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())), + ) + continue + + # Not streaming vq + if tokens[0] == semantic_id: + yield from send_reset_buffer(sample_id) + # None streaming vq + if len(parts[sample_id]) == 0 or not isinstance( + parts[sample_id][-1], ServeVQPart + ): + _tokens = tokens[1:].clone() - 1 + + if config.share_codebook_embeddings is False: + for i in range(len(_tokens)): + _tokens[i] -= config.codebook_size * i + + parts[sample_id].append(ServeVQPart(codes=_tokens.tolist())) + else: + for codebook_id, value in enumerate(tokens[1:, :]): + val = value.item() - 1 + if config.share_codebook_embeddings is False: + val -= config.codebook_size * codebook_id + + parts[sample_id][-1].codes[codebook_id].append(val) + continue + + if tokens[0] != semantic_id: + # Stream text decode is not supported now + 
decode_buffer[sample_id].append(tokens[0, 0]) + + if idx == 0: + stats["time_to_first_token"] = (time.time() - start) * 1000 + + idx += 1 + + for sample_id in range(request.num_samples): + yield from send_reset_buffer(sample_id) + + stats["total_time"] = (time.time() - start) * 1000 + stats["total_tokens"] = idx + + if request.streaming: + for sample_id in range(request.num_samples): + if finished[sample_id]: + continue + yield ServeStreamResponse( + finish_reason=response, stats=stats, sample_id=sample_id + ) + return + + yield ServeResponse( + messages=[ + ServeMessage(role="assistant", parts=parts[i]) + for i in range(request.num_samples) + ], + finish_reason=response, + stats=stats, + ) + + +# @routes.http.post("/v1/chat") +# def api_invoke_chat( +# req: Annotated[ServeRequest, Body(exclusive=True)], +# ): +# """ +# Invoke model and generate audio +# """ +# +# # This makes torch compile happy +# assert ( +# req.num_samples == GLOBAL_NUM_SAMPLES +# ), f"num_samples must be {GLOBAL_NUM_SAMPLES}" +# +# content_type = request.headers.get("Content-Type", "application/json") +# json_mode = "application/json" in content_type +# +# async def wrapped_generator(): +# generator = execute_request(llama_queue, tokenizer, config, req, args.device) +# +# for i in generator: +# if json_mode: +# body = i.model_dump_json().encode("utf-8") +# yield b"data: " + body + b"\n\n" +# else: +# body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC) +# yield struct.pack("I", len(body)) + body +# +# # Naive mode +# if req.streaming is False: +# result = next(execute_request(llama_queue, tokenizer, config, req, args.device)) +# +# if json_mode: +# return JSONResponse(result.model_dump()) +# else: +# return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC) +# +# return StreamResponse( +# iterable=wrapped_generator(), content_type="text/event-stream" +# ) + + @torch.inference_mode() def inference(req: ServeTTSRequest): + global prompt_tokens, prompt_texts + idstr: str | None = req.reference_id if idstr is not None: ref_folder = Path("references") / idstr @@ -177,33 +603,47 @@ def inference(req: ServeTTSRequest): ref_audios = list_files( ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False ) - prompt_tokens = [ - encode_reference( - decoder_model=decoder_model, - reference_audio=audio_to_bytes(str(ref_audio)), - enable_reference_audio=True, - ) - for ref_audio in ref_audios - ] - prompt_texts = [ - read_ref_text(str(ref_audio.with_suffix(".lab"))) - for ref_audio in ref_audios - ] + + if req.use_memory_cache == "never" or ( + req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0 + ): + prompt_tokens = [ + encode_reference( + decoder_model=decoder_model, + reference_audio=audio_to_bytes(str(ref_audio)), + enable_reference_audio=True, + ) + for ref_audio in ref_audios + ] + prompt_texts = [ + read_ref_text(str(ref_audio.with_suffix(".lab"))) + for ref_audio in ref_audios + ] + else: + logger.info("Use same references") else: # Parse reference audio aka prompt refs = req.references - if refs is None: - refs = [] - prompt_tokens = [ - encode_reference( - decoder_model=decoder_model, - reference_audio=ref.audio, - enable_reference_audio=True, - ) - for ref in refs - ] - prompt_texts = [ref.text for ref in refs] + + if req.use_memory_cache == "never" or ( + req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0 + ): + prompt_tokens = [ + encode_reference( + decoder_model=decoder_model, + reference_audio=ref.audio, + enable_reference_audio=True, + ) + for ref in refs + ] + 
prompt_texts = [ref.text for ref in refs] + else: + logger.info("Use same references") + + if req.seed is not None: + set_seed(req.seed) + logger.warning(f"set seed: {req.seed}") # LLAMA Inference request = dict( @@ -220,7 +660,7 @@ def inference(req: ServeTTSRequest): compile=args.compile, iterative_prompt=req.chunk_length > 0, chunk_length=req.chunk_length, - max_length=2048, + max_length=4096, prompt_tokens=prompt_tokens, prompt_text=prompt_texts, ) @@ -342,6 +782,8 @@ async def buffer_to_async_generator(buffer): def parse_args(): parser = ArgumentParser() + parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts") + parser.add_argument("--load-asr-model", action="store_true") parser.add_argument( "--llama-checkpoint-path", type=str, @@ -367,18 +809,26 @@ def parse_args(): # openapi = OpenAPI( # { # "title": "Fish Speech API", +# "version": "1.4.2", # }, # ).routes # # # class MsgPackRequest(HttpRequest): -# async def data(self) -> Annotated[Any, ContentType("application/msgpack")]: +# async def data( +# self, +# ) -> Annotated[ +# Any, ContentType("application/msgpack"), ContentType("application/json") +# ]: # if self.content_type == "application/msgpack": # return ormsgpack.unpackb(await self.body) # +# elif self.content_type == "application/json": +# return await self.json +# # raise HTTPException( # HTTPStatus.UNSUPPORTED_MEDIA_TYPE, -# headers={"Accept": "application/msgpack"}, +# headers={"Accept": "application/msgpack, application/json"}, # ) # # @@ -393,48 +843,101 @@ def parse_args(): # ) -if __name__ == "__main__": +def load_asr_model(*, device="cuda", hub="ms"): + return AutoModel( + model="iic/SenseVoiceSmall", + device=device, + disable_pbar=True, + hub=hub, + ) - import uvicorn - args = parse_args() - args.precision = torch.half if args.half else torch.bfloat16 +# Each worker process created by Uvicorn has its own memory space, +# meaning that models and variables are not shared between processes. +# Therefore, any global variables (like `llama_queue` or `decoder_model`) +# will not be shared across workers. - logger.info("Loading Llama model...") - llama_queue = launch_thread_safe_queue( - checkpoint_path=args.llama_checkpoint_path, - device=args.device, - precision=args.precision, - compile=args.compile, - ) - logger.info("Llama model loaded, loading VQ-GAN model...") - decoder_model = load_decoder_model( - config_name=args.decoder_config_name, - checkpoint_path=args.decoder_checkpoint_path, - device=args.device, - ) +# Multi-threading for deep learning can cause issues, such as inconsistent +# outputs if multiple threads access the same buffers simultaneously. +# Instead, it's better to use multiprocessing or independent models per thread. 
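The comment block above explains why inference must stay single-threaded within each worker process: Uvicorn workers do not share memory, and concurrent threads reusing the same model buffers can produce inconsistent outputs. A minimal sketch of the lock-guarded pattern those comments describe, assuming a generic `model` object exposing a `generate()` method — the names below are illustrative and are not identifiers from the diff:

```python
import threading

# One lock per process: only a single thread touches the model's shared
# buffers at a time, matching the "one process, one model" guidance above.
_model_lock = threading.Lock()

def run_inference(model, batch):
    """Serialize access to an in-process model instance (illustrative only)."""
    with _model_lock:
        # The forward pass and any internal buffer reuse happen under the lock,
        # so concurrent requests cannot interleave on the same tensors.
        return model.generate(batch)
```

Serializing calls this way trades per-process throughput for correctness; scaling out is then handled by running additional worker processes, each holding its own copy of the model.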
+# @app.on_startup +# def initialize_app(app: Kui): +# +# global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts +# +# prompt_tokens, prompt_texts = [], [] +# +# args = parse_args() # args same as ones in other processes +# args.precision = torch.half if args.half else torch.bfloat16 +# +# if args.load_asr_model: +# logger.info(f"Loading ASR model...") +# asr_model = load_asr_model(device=args.device) +# +# logger.info("Loading Llama model...") +# +# if args.mode == "tts": +# llama_queue = launch_thread_safe_queue( +# checkpoint_path=args.llama_checkpoint_path, +# device=args.device, +# precision=args.precision, +# compile=args.compile, +# ) +# else: +# llama_queue, tokenizer, config = launch_thread_safe_queue_agent( +# checkpoint_path=args.llama_checkpoint_path, +# device=args.device, +# precision=args.precision, +# compile=args.compile, +# ) +# +# logger.info("Llama model loaded, loading VQ-GAN model...") +# +# decoder_model = load_decoder_model( +# config_name=args.decoder_config_name, +# checkpoint_path=args.decoder_checkpoint_path, +# device=args.device, +# ) +# +# logger.info("VQ-GAN model loaded, warming up...") +# +# vad_model = load_silero_vad() +# +# logger.info("VAD model loaded, warming up...") +# +# if args.mode == "tts": +# # Dry run to ensure models work and avoid first-time latency +# list( +# inference( +# ServeTTSRequest( +# text="Hello world.", +# references=[], +# reference_id=None, +# max_new_tokens=0, +# chunk_length=200, +# top_p=0.7, +# repetition_penalty=1.2, +# temperature=0.7, +# emotion=None, +# format="wav", +# ) +# ) +# ) +# +# logger.info(f"Warming up done, starting server at http://{args.listen}") - logger.info("VQ-GAN model loaded, warming up...") - - # Dry run to check if the model is loaded correctly and avoid the first-time latency - list( - inference( - ServeTTSRequest( - text="Hello world.", - references=[], - reference_id=None, - max_new_tokens=1024, - chunk_length=200, - top_p=0.7, - repetition_penalty=1.2, - temperature=0.7, - emotion=None, - format="wav", - ) - ) - ) - logger.info(f"Warming up done, starting server at http://{args.listen}") +if __name__ == "__main__": + + import uvicorn + + args = parse_args() host, port = args.listen.split(":") - uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info") + uvicorn.run( + "tools.api:app", + host=host, + port=int(port), + workers=args.workers, + log_level="info", + ) diff --git a/xinference/thirdparty/fish_speech/tools/commons.py b/xinference/thirdparty/fish_speech/tools/commons.py deleted file mode 100644 index f81cadec1e..0000000000 --- a/xinference/thirdparty/fish_speech/tools/commons.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Annotated, Literal, Optional - -from pydantic import BaseModel, Field, conint - - -class ServeReferenceAudio(BaseModel): - audio: bytes - text: str - - -class ServeTTSRequest(BaseModel): - text: str - chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200 - # Audio format - format: Literal["wav", "pcm", "mp3"] = "wav" - mp3_bitrate: Literal[64, 128, 192] = 128 - # References audios for in-context learning - references: list[ServeReferenceAudio] = [] - # Reference id - # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/ - # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1 - reference_id: str | None = None - # Normalize text for en & zh, this increase stability for numbers - normalize: bool = True - mp3_bitrate: Optional[int] = 64 - 
opus_bitrate: Optional[int] = -1000 - # Balance mode will reduce latency to 300ms, but may decrease stability - latency: Literal["normal", "balanced"] = "normal" - # not usually used below - streaming: bool = False - emotion: Optional[str] = None - max_new_tokens: int = 1024 - top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 - repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2 - temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 diff --git a/xinference/thirdparty/fish_speech/tools/e2e_webui.py b/xinference/thirdparty/fish_speech/tools/e2e_webui.py new file mode 100644 index 0000000000..37474fbd56 --- /dev/null +++ b/xinference/thirdparty/fish_speech/tools/e2e_webui.py @@ -0,0 +1,232 @@ +import io +import re +import wave + +import gradio as gr +import numpy as np + +from .fish_e2e import FishE2EAgent, FishE2EEventType +from .schema import ServeMessage, ServeTextPart, ServeVQPart + + +def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): + buffer = io.BytesIO() + + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(bit_depth // 8) + wav_file.setframerate(sample_rate) + + wav_header_bytes = buffer.getvalue() + buffer.close() + return wav_header_bytes + + +class ChatState: + def __init__(self): + self.conversation = [] + self.added_systext = False + self.added_sysaudio = False + + def get_history(self): + results = [] + for msg in self.conversation: + results.append({"role": msg.role, "content": self.repr_message(msg)}) + + # Process assistant messages to extract questions and update user messages + for i, msg in enumerate(results): + if msg["role"] == "assistant": + match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"]) + if match and i > 0 and results[i - 1]["role"] == "user": + # Update previous user message with extracted question + results[i - 1]["content"] += "\n" + match.group(1) + # Remove the Question/Answer format from assistant message + msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1] + return results + + def repr_message(self, msg: ServeMessage): + response = "" + for part in msg.parts: + if isinstance(part, ServeTextPart): + response += part.text + elif isinstance(part, ServeVQPart): + response += f"