diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml
index 3fedc5eea5..81a16122c5 100644
--- a/.github/workflows/python.yaml
+++ b/.github/workflows/python.yaml
@@ -162,7 +162,7 @@ jobs:
${{ env.SELF_HOST_PYTHON }} -m pip install -U WeTextProcessing<1.0.4
${{ env.SELF_HOST_PYTHON }} -m pip install -U librosa
${{ env.SELF_HOST_PYTHON }} -m pip install -U xxhash
- ${{ env.SELF_HOST_PYTHON }} -m pip install -U "ChatTTS>0.1,<0.2"
+ ${{ env.SELF_HOST_PYTHON }} -m pip install -U "ChatTTS>=0.2"
${{ env.SELF_HOST_PYTHON }} -m pip install -U HyperPyYAML
${{ env.SELF_HOST_PYTHON }} -m pip uninstall -y matcha-tts
${{ env.SELF_HOST_PYTHON }} -m pip install -U onnxruntime-gpu==1.16.0; sys_platform == 'linux'
@@ -175,9 +175,13 @@ jobs:
${{ env.SELF_HOST_PYTHON }} -m pip uninstall -y opencc
${{ env.SELF_HOST_PYTHON }} -m pip uninstall -y "faster_whisper"
${{ env.SELF_HOST_PYTHON }} -m pip install -U accelerate
+ ${{ env.SELF_HOST_PYTHON }} -m pip install -U verovio
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/image/tests/test_stable_diffusion.py && \
+ ${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
+ -W ignore::PendingDeprecationWarning \
+ --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/image/tests/test_got_ocr2.py && \
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_whisper.py && \
@@ -203,6 +207,6 @@ jobs:
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/client/tests/test_client.py
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
- --cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py --ignore xinference/model/audio/tests xinference
+ --cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py --ignore xinference/model/image/tests/test_got_ocr2.py --ignore xinference/model/audio/tests xinference
fi
working-directory: .
diff --git a/README.md b/README.md
index cd622e7b07..d63cbb42ff 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
Xinference Cloud ·
- Xinference Enterprise ·
+ Xinference Enterprise ·
Self-hosting ·
Documentation
diff --git a/README_zh_CN.md b/README_zh_CN.md
index 2ace2e1bcc..2df28e2632 100644
--- a/README_zh_CN.md
+++ b/README_zh_CN.md
@@ -5,7 +5,7 @@
Xinference 云服务 ·
- Xinference 企业版 ·
+ Xinference 企业版 ·
自托管 ·
文档
diff --git a/doc/source/models/builtin/embedding/gte-qwen2.rst b/doc/source/models/builtin/embedding/gte-qwen2.rst
index a88fdece9d..85eeeac39a 100644
--- a/doc/source/models/builtin/embedding/gte-qwen2.rst
+++ b/doc/source/models/builtin/embedding/gte-qwen2.rst
@@ -11,11 +11,11 @@ gte-Qwen2
Specifications
^^^^^^^^^^^^^^
-- **Dimensions:** 3584
+- **Dimensions:** 4096
- **Max Tokens:** 32000
- **Model ID:** Alibaba-NLP/gte-Qwen2-7B-instruct
- **Model Hubs**: `Hugging Face `__, `ModelScope `__
Execute the following command to launch the model::
- xinference launch --model-name gte-Qwen2 --model-type embedding
\ No newline at end of file
+ xinference launch --model-name gte-Qwen2 --model-type embedding
diff --git a/setup.cfg b/setup.cfg
index cd3a752de6..3c08363e59 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -33,7 +33,7 @@ install_requires =
tabulate
requests
pydantic
- fastapi==0.110.3
+ fastapi>=0.110.3
uvicorn
huggingface-hub>=0.19.4
typing_extensions
@@ -80,13 +80,13 @@ all =
llama-cpp-python>=0.2.25,!=0.2.58
transformers>=4.43.2
torch>=2.0.0 # >=2.0 For CosyVoice
- accelerate>=0.27.2
+ accelerate>=0.28.0
sentencepiece
transformers_stream_generator
bitsandbytes
protobuf
einops
- tiktoken
+ tiktoken>=0.6.0
sentence-transformers>=3.1.0
vllm>=0.2.6 ; sys_platform=='linux'
diffusers>=0.30.0
@@ -110,7 +110,7 @@ all =
librosa # For ChatTTS
xxhash # For ChatTTS
torchaudio # For ChatTTS
- ChatTTS>0.1,<0.2
+ ChatTTS>=0.2
lightning>=2.0.0 # For CosyVoice, matcha
hydra-core>=1.3.2 # For CosyVoice, matcha
inflect # For CosyVoice, matcha
@@ -131,6 +131,8 @@ all =
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
+ verovio>=4.3.1 # For got_ocr2
+ accelerate>=0.28.0 # For got_ocr2
intel =
torch==2.1.0a0
intel_extension_for_pytorch==2.1.10+xpu
@@ -139,7 +141,7 @@ llama_cpp =
transformers =
transformers>=4.43.2
torch
- accelerate>=0.27.2
+ accelerate>=0.28.0
sentencepiece
transformers_stream_generator
bitsandbytes
@@ -174,6 +176,12 @@ image =
diffusers>=0.30.0 # fix conflict with matcha-tts
controlnet_aux
deepcache
+ verovio>=4.3.1 # For got_ocr2
+ transformers>=4.37.2 # For got_ocr2
+ tiktoken>=0.6.0 # For got_ocr2
+ accelerate>=0.28.0 # For got_ocr2
+ torch # For got_ocr2
+ torchvision # For got_ocr2
video =
diffusers>=0.30.0
imageio-ffmpeg
@@ -185,7 +193,7 @@ audio =
librosa
xxhash
torchaudio
- ChatTTS>0.1,<0.2
+ ChatTTS>=0.2
tiktoken # For CosyVoice, openai-whisper
torch>=2.0.0 # For CosyVoice, matcha
lightning>=2.0.0 # For CosyVoice, matcha
diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index 56d6c89f2d..ed3a2eab90 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -567,6 +567,16 @@ async def internal_exception_handler(request: Request, exc: Exception):
else None
),
)
+ self._router.add_api_route(
+ "/v1/images/ocr",
+ self.create_ocr,
+ methods=["POST"],
+ dependencies=(
+ [Security(self._auth_service, scopes=["models:read"])]
+ if self.is_authenticated()
+ else None
+ ),
+ )
# SD WebUI API
self._router.add_api_route(
"/sdapi/v1/options",
@@ -1754,6 +1764,44 @@ async def create_inpainting(
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))
+ async def create_ocr(
+ self,
+ model: str = Form(...),
+ image: UploadFile = File(media_type="application/octet-stream"),
+ kwargs: Optional[str] = Form(None),
+ ) -> Response:
+ model_uid = model
+ try:
+ model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
+ except ValueError as ve:
+ logger.error(str(ve), exc_info=True)
+ await self._report_error_event(model_uid, str(ve))
+ raise HTTPException(status_code=400, detail=str(ve))
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ await self._report_error_event(model_uid, str(e))
+ raise HTTPException(status_code=500, detail=str(e))
+
+ try:
+ if kwargs is not None:
+ parsed_kwargs = json.loads(kwargs)
+ else:
+ parsed_kwargs = {}
+ im = Image.open(image.file)
+ text = await model_ref.ocr(
+ image=im,
+ **parsed_kwargs,
+ )
+ return Response(content=text, media_type="text/plain")
+ except RuntimeError as re:
+ logger.error(re, exc_info=True)
+ await self._report_error_event(model_uid, str(re))
+ raise HTTPException(status_code=400, detail=str(re))
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ await self._report_error_event(model_uid, str(e))
+ raise HTTPException(status_code=500, detail=str(e))
+
async def create_flexible_infer(self, request: Request) -> Response:
payload = await request.json()
diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py
index 438ead7390..dd5e3f1146 100644
--- a/xinference/client/restful/restful_client.py
+++ b/xinference/client/restful/restful_client.py
@@ -369,6 +369,25 @@ def inpainting(
response_data = response.json()
return response_data
+ def ocr(self, image: Union[str, bytes], **kwargs):
+ url = f"{self._base_url}/v1/images/ocr"
+ params = {
+ "model": self._model_uid,
+ "kwargs": json.dumps(kwargs),
+ }
+ files: List[Any] = []
+ for key, value in params.items():
+ files.append((key, (None, value)))
+ files.append(("image", ("image", image, "application/octet-stream")))
+ response = requests.post(url, files=files, headers=self.auth_headers)
+ if response.status_code != 200:
+ raise RuntimeError(
+ f"Failed to ocr the images, detail: {_get_error_string(response)}"
+ )
+
+ response_data = response.json()
+ return response_data
+
class RESTfulVideoModelHandle(RESTfulModelHandle):
def text_to_video(
diff --git a/xinference/client/tests/test_client.py b/xinference/client/tests/test_client.py
index 5e85556d05..df28e8260a 100644
--- a/xinference/client/tests/test_client.py
+++ b/xinference/client/tests/test_client.py
@@ -99,15 +99,9 @@ def _check_stream():
for _ in range(2):
r = executor.submit(_check_stream)
results.append(r)
- # Parallel generation is not supported by llama-cpp-python.
- error_count = 0
+
for r in results:
- try:
- r.result()
- except Exception as ex:
- assert "Parallel generation" in str(ex)
- error_count += 1
- assert error_count == 1
+ r.result()
# After iteration finish, we can iterate again.
_check_stream()
@@ -143,18 +137,12 @@ def _check(stream=False):
for stream in [True, False]:
results = []
- error_count = 0
with ThreadPoolExecutor() as executor:
for _ in range(3):
r = executor.submit(_check, stream=stream)
results.append(r)
for r in results:
- try:
- r.result()
- except Exception as ex:
- assert "Parallel generation" in str(ex)
- error_count += 1
- assert error_count == (2 if stream else 0)
+ r.result()
client.terminate_model(model_uid=model_uid)
assert len(client.list_models()) == 0
diff --git a/xinference/core/chat_interface.py b/xinference/core/chat_interface.py
index 9de2dab252..08b30ab054 100644
--- a/xinference/core/chat_interface.py
+++ b/xinference/core/chat_interface.py
@@ -74,7 +74,11 @@ def build(self) -> "gr.Blocks":
# Gradio initiates the queue during a startup event, but since the app has already been
# started, that event will not run, so manually invoke the startup events.
# See: https://github.com/gradio-app/gradio/issues/5228
- interface.startup_events()
+ try:
+ interface.run_startup_events()
+ except AttributeError:
+ # compatibility
+ interface.startup_events()
favicon_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
os.path.pardir,
diff --git a/xinference/core/image_interface.py b/xinference/core/image_interface.py
index 56761e4101..b48636bfd5 100644
--- a/xinference/core/image_interface.py
+++ b/xinference/core/image_interface.py
@@ -63,7 +63,11 @@ def build(self) -> gr.Blocks:
# Gradio initiates the queue during a startup event, but since the app has already been
# started, that event will not run, so manually invoke the startup events.
# See: https://github.com/gradio-app/gradio/issues/5228
- interface.startup_events()
+ try:
+ interface.run_startup_events()
+ except AttributeError:
+ # compatibility
+ interface.startup_events()
favicon_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
os.path.pardir,
diff --git a/xinference/core/model.py b/xinference/core/model.py
index 206adc25d9..567ef81769 100644
--- a/xinference/core/model.py
+++ b/xinference/core/model.py
@@ -17,10 +17,10 @@
import inspect
import json
import os
+import queue
import time
import types
import uuid
-import weakref
from asyncio.queues import Queue
from asyncio.tasks import wait_for
from concurrent.futures import Future as ConcurrentFuture
@@ -32,7 +32,6 @@
Callable,
Dict,
Generator,
- Iterator,
List,
Optional,
Union,
@@ -209,9 +208,8 @@ def __init__(
model_description.to_dict() if model_description else {}
)
self._request_limits = request_limits
-
- self._generators: Dict[str, Union[Iterator, AsyncGenerator]] = {}
- self._current_generator = lambda: None
+ self._pending_requests: asyncio.Queue = asyncio.Queue()
+ self._handle_pending_requests_task = None
self._lock = (
None
if isinstance(
@@ -237,6 +235,10 @@ def __init__(
async def __post_create__(self):
self._loop = asyncio.get_running_loop()
+ self._handle_pending_requests_task = asyncio.create_task(
+ self._handle_pending_requests()
+ )
+
if self.allow_batching():
from .scheduler import SchedulerActor
@@ -474,6 +476,43 @@ async def _to_async_gen(self, output_type: str, gen: types.AsyncGeneratorType):
)
await asyncio.gather(*coros)
+ async def _handle_pending_requests(self):
+ logger.info("Start requests handler.")
+ while True:
+ gen, stream_out, stop = await self._pending_requests.get()
+
+ async def _async_wrapper(_gen):
+ try:
+ # anext is only available for Python >= 3.10
+ return await _gen.__anext__() # noqa: F821
+ except StopAsyncIteration:
+ return stop
+
+ def _wrapper(_gen):
+ # Avoid issue: https://github.com/python/cpython/issues/112182
+ try:
+ return next(_gen)
+ except StopIteration:
+ return stop
+
+ while True:
+ try:
+ if inspect.isgenerator(gen):
+ r = await asyncio.to_thread(_wrapper, gen)
+ elif inspect.isasyncgen(gen):
+ r = await _async_wrapper(gen)
+ else:
+ raise Exception(
+ f"The generator {gen} should be a generator or an async generator, "
+ f"but a {type(gen)} is got."
+ )
+ stream_out.put_nowait(r)
+ if r is not stop:
+ continue
+ except Exception:
+ logger.exception("stream encountered an error.")
+ break
+
async def _call_wrapper_json(self, fn: Callable, *args, **kwargs):
return await self._call_wrapper("json", fn, *args, **kwargs)
@@ -487,6 +526,13 @@ async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
ret = await fn(*args, **kwargs)
else:
ret = await asyncio.to_thread(fn, *args, **kwargs)
+
+ if inspect.isgenerator(ret):
+ gen = self._to_generator(output_type, ret)
+ return gen
+ if inspect.isasyncgen(ret):
+ gen = self._to_async_gen(output_type, ret)
+ return gen
else:
async with self._lock:
if inspect.iscoroutinefunction(fn):
@@ -494,17 +540,40 @@ async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
else:
ret = await asyncio.to_thread(fn, *args, **kwargs)
- if self._lock is not None and self._current_generator():
- raise Exception("Parallel generation is not supported by llama-cpp-python.")
+ stream_out: Union[queue.Queue, asyncio.Queue]
+
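+ # Streaming results are relayed through a per-request queue that the background
+ # handler task fills; the caller consumes it until the stop sentinel appears.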
+ if inspect.isgenerator(ret):
+ gen = self._to_generator(output_type, ret)
+ stream_out = queue.Queue()
+ stop = object()
+ self._pending_requests.put_nowait((gen, stream_out, stop))
+
+ def _stream_out_generator():
+ while True:
+ o = stream_out.get()
+ if o is stop:
+ break
+ else:
+ yield o
+
+ return _stream_out_generator()
+
+ if inspect.isasyncgen(ret):
+ gen = self._to_async_gen(output_type, ret)
+ stream_out = asyncio.Queue()
+ stop = object()
+ self._pending_requests.put_nowait((gen, stream_out, stop))
+
+ async def _stream_out_async_gen():
+ while True:
+ o = await stream_out.get()
+ if o is stop:
+ break
+ else:
+ yield o
+
+ return _stream_out_async_gen()
- if inspect.isgenerator(ret):
- gen = self._to_generator(output_type, ret)
- self._current_generator = weakref.ref(gen)
- return gen
- if inspect.isasyncgen(ret):
- gen = self._to_async_gen(output_type, ret)
- self._current_generator = weakref.ref(gen)
- return gen
if output_type == "json":
return await asyncio.to_thread(json_dumps, ret)
else:
@@ -592,7 +661,6 @@ async def handle_batching_request(
prompt_or_messages, queue, call_ability, *args, **kwargs
)
gen = self._to_async_gen("json", ret)
- self._current_generator = weakref.ref(gen)
return gen
else:
from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
@@ -953,6 +1021,25 @@ async def inpainting(
f"Model {self._model.model_spec} is not for creating image."
)
+ @log_async(
+ logger=logger,
+ ignore_kwargs=["image"],
+ )
+ async def ocr(
+ self,
+ image: "PIL.Image",
+ *args,
+ **kwargs,
+ ):
+ if hasattr(self._model, "ocr"):
+ return await self._call_wrapper_json(
+ self._model.ocr,
+ image,
+ *args,
+ **kwargs,
+ )
+ raise AttributeError(f"Model {self._model.model_spec} is not for ocr.")
+
@request_limit
@log_async(logger=logger, ignore_kwargs=["image"])
async def infer(
@@ -994,3 +1081,6 @@ async def text_to_video(
async def record_metrics(self, name, op, kwargs):
worker_ref = await self._get_worker_ref()
await worker_ref.record_metrics(name, op, kwargs)
+
+ async def get_pending_requests_count(self):
+ return self._pending_requests.qsize()
diff --git a/xinference/core/scheduler.py b/xinference/core/scheduler.py
index 1b91d62e27..8b91855daa 100644
--- a/xinference/core/scheduler.py
+++ b/xinference/core/scheduler.py
@@ -79,7 +79,7 @@ def __init__(
# For tool call
self.tools = None
# Currently, for storing tool call streaming results.
- self.outputs: List[str] = []
+ self.outputs: List[str] = [] # type: ignore
# inference results,
# it is a list type because when stream=True,
# self.completion contains all the results in a decode round.
diff --git a/xinference/core/tests/test_model.py b/xinference/core/tests/test_model.py
new file mode 100644
index 0000000000..655debf799
--- /dev/null
+++ b/xinference/core/tests/test_model.py
@@ -0,0 +1,108 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+
+import pytest
+import pytest_asyncio
+import xoscar as xo
+from xoscar import create_actor_pool
+
+from ..model import ModelActor
+
+TEST_EVENT = None
+TEST_VALUE = None
+
+
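+# Minimal stand-in model whose generate() blocks on TEST_EVENT, so the test can
+# assert how many requests are queued in ModelActor._pending_requests.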
+class MockModel:
+ async def generate(self, prompt, **kwargs):
+ global TEST_VALUE
+ TEST_VALUE = True
+ assert isinstance(TEST_EVENT, asyncio.Event)
+ await TEST_EVENT.wait()
+ yield {"test1": prompt}
+ yield {"test2": prompt}
+
+
+class MockModelActor(ModelActor):
+ def __init__(
+ self,
+ supervisor_address: str,
+ worker_address: str,
+ ):
+ super().__init__(supervisor_address, worker_address, MockModel()) # type: ignore
+ self._lock = asyncio.locks.Lock()
+
+ async def __pre_destroy__(self):
+ pass
+
+ async def record_metrics(self, name, op, kwargs):
+ pass
+
+
+@pytest_asyncio.fixture
+async def setup_pool():
+ pool = await create_actor_pool(
+ f"test://127.0.0.1:{xo.utils.get_next_port()}", n_process=0
+ )
+ async with pool:
+ yield pool
+
+
+@pytest.mark.asyncio
+async def test_concurrent_call(setup_pool):
+ pool = setup_pool
+ addr = pool.external_address
+
+ global TEST_EVENT
+ TEST_EVENT = asyncio.Event()
+
+ worker: xo.ActorRefType[MockModelActor] = await xo.create_actor( # type: ignore
+ MockModelActor,
+ address=addr,
+ uid=MockModelActor.default_uid(),
+ supervisor_address="test:123",
+ worker_address="test:345",
+ )
+
+ await worker.generate("test_prompt1")
+ assert TEST_VALUE is not None
+ # This request is waiting for the TEST_EVENT, so the queue is empty.
+ pending_count = await worker.get_pending_requests_count()
+ assert pending_count == 0
+ await worker.generate("test_prompt3")
+ # This request is waiting in the queue because the previous request is waiting for TEST_EVENT.
+ pending_count = await worker.get_pending_requests_count()
+ assert pending_count == 1
+
+ async def _check():
+ gen = await worker.generate("test_prompt2")
+ result = []
+ async for g in gen:
+ result.append(g)
+ assert result == [
+ b'data: {"test1": "test_prompt2"}\r\n\r\n',
+ b'data: {"test2": "test_prompt2"}\r\n\r\n',
+ ]
+
+ check_task = asyncio.create_task(_check())
+ await asyncio.sleep(2)
+ assert not check_task.done()
+ # Pending 2 requests: test_prompt3 and test_prompt2
+ pending_count = await worker.get_pending_requests_count()
+ assert pending_count == 2
+ TEST_EVENT.set()
+ await check_task
+ pending_count = await worker.get_pending_requests_count()
+ assert pending_count == 0
diff --git a/xinference/deploy/docker/requirements.txt b/xinference/deploy/docker/requirements.txt
index 6de778c5b9..a3aa0a5e93 100644
--- a/xinference/deploy/docker/requirements.txt
+++ b/xinference/deploy/docker/requirements.txt
@@ -8,7 +8,7 @@ tqdm>=4.27
tabulate
requests
pydantic
-fastapi==0.110.3
+fastapi>=0.110.3
uvicorn
huggingface-hub>=0.19.4
typing_extensions
@@ -24,14 +24,14 @@ peft
opencv-contrib-python-headless
# all
-transformers>=4.34.1
-accelerate>=0.27.2
+transformers>=4.43.2
+accelerate>=0.28.0
sentencepiece
transformers_stream_generator
bitsandbytes
protobuf
einops
-tiktoken
+tiktoken>=0.6.0
sentence-transformers>=3.1.0
diffusers>=0.30.0
controlnet_aux
@@ -49,7 +49,7 @@ nemo_text_processing<1.1.0 # 1.1.0 requires pynini==2.1.6.post1
WeTextProcessing<1.0.4 # 1.0.4 requires pynini==2.1.6
librosa # For ChatTTS
torchaudio # For ChatTTS
-ChatTTS>0.1,<0.2
+ChatTTS>=0.2
xxhash # For ChatTTS
torch>=2.0.0 # For CosyVoice
lightning>=2.0.0 # For CosyVoice, matcha
@@ -75,6 +75,7 @@ qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
deepcache # for sd
+verovio>=4.3.1 # For got_ocr2
# sglang
outlines>=0.0.44
diff --git a/xinference/deploy/docker/requirements_cpu.txt b/xinference/deploy/docker/requirements_cpu.txt
index 50cd791995..9eb9409b4f 100644
--- a/xinference/deploy/docker/requirements_cpu.txt
+++ b/xinference/deploy/docker/requirements_cpu.txt
@@ -7,7 +7,7 @@ tqdm>=4.27
tabulate
requests
pydantic
-fastapi==0.110.3
+fastapi>=0.110.3
uvicorn
huggingface-hub>=0.19.4
typing_extensions
@@ -21,8 +21,8 @@ passlib[bcrypt]
aioprometheus[starlette]>=23.12.0
nvidia-ml-py
async-timeout
-transformers>=4.34.1
-accelerate>=0.20.3
+transformers>=4.43.2
+accelerate>=0.28.0
sentencepiece
transformers_stream_generator
bitsandbytes
@@ -46,7 +46,7 @@ nemo_text_processing<1.1.0 # 1.1.0 requires pynini==2.1.6.post1
WeTextProcessing<1.0.4 # 1.0.4 requires pynini==2.1.6
librosa # For ChatTTS
torchaudio # For ChatTTS
-ChatTTS>0.1,<0.2
+ChatTTS>=0.2
xxhash # For ChatTTS
torch>=2.0.0 # For CosyVoice
lightning>=2.0.0 # For CosyVoice, matcha
@@ -69,3 +69,4 @@ ormsgpack # For Fish Speech
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
+verovio>=4.3.1 # For got_ocr2
diff --git a/xinference/deploy/supervisor.py b/xinference/deploy/supervisor.py
index aac1e78d3e..ed12a9f7c2 100644
--- a/xinference/deploy/supervisor.py
+++ b/xinference/deploy/supervisor.py
@@ -31,10 +31,6 @@
logger = logging.getLogger(__name__)
-from ..model import _install as install_model
-
-install_model()
-
async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
logging.config.dictConfig(logging_conf) # type: ignore
diff --git a/xinference/model/audio/chattts.py b/xinference/model/audio/chattts.py
index b636fe59b4..2a5f4ee7c9 100644
--- a/xinference/model/audio/chattts.py
+++ b/xinference/model/audio/chattts.py
@@ -54,7 +54,11 @@ def load(self):
torch.set_float32_matmul_precision("high")
self._model = ChatTTS.Chat()
logger.info("Load ChatTTS model with kwargs: %s", self._kwargs)
- self._model.load(source="custom", custom_path=self._model_path, **self._kwargs)
+ ok = self._model.load(
+ source="custom", custom_path=self._model_path, **self._kwargs
+ )
+ if not ok:
+ raise Exception(f"The ChatTTS model is not correct: {self._model_path}")
def speech(
self,
@@ -114,16 +118,15 @@ def _generator():
last_pos = 0
with writer.open():
for it in iter:
- for itt in it:
- for chunk in itt:
- chunk = np.array([chunk]).transpose()
- writer.write_audio_chunk(i, torch.from_numpy(chunk))
- new_last_pos = out.tell()
- if new_last_pos != last_pos:
- out.seek(last_pos)
- encoded_bytes = out.read()
- yield encoded_bytes
- last_pos = new_last_pos
+ for chunk in it:
+ chunk = np.array([chunk]).transpose()
+ writer.write_audio_chunk(i, torch.from_numpy(chunk))
+ new_last_pos = out.tell()
+ if new_last_pos != last_pos:
+ out.seek(last_pos)
+ encoded_bytes = out.read()
+ yield encoded_bytes
+ last_pos = new_last_pos
return _generator()
else:
@@ -131,7 +134,15 @@ def _generator():
# Save the generated audio
with BytesIO() as out:
- torchaudio.save(
- out, torch.from_numpy(wavs[0]), 24000, format=response_format
- )
+ try:
+ torchaudio.save(
+ out,
+ torch.from_numpy(wavs[0]).unsqueeze(0),
+ 24000,
+ format=response_format,
+ )
+ except Exception:
+ torchaudio.save(
+ out, torch.from_numpy(wavs[0]), 24000, format=response_format
+ )
return out.getvalue()
diff --git a/xinference/model/audio/model_spec.json b/xinference/model/audio/model_spec.json
index bf51b3da3a..e0328dd375 100644
--- a/xinference/model/audio/model_spec.json
+++ b/xinference/model/audio/model_spec.json
@@ -127,7 +127,7 @@
"model_name": "ChatTTS",
"model_family": "ChatTTS",
"model_id": "2Noise/ChatTTS",
- "model_revision": "ce5913842aebd78e4a01a02d47244b8d62ac4ee3",
+ "model_revision": "3b34118f6d25850440b8901cef3e71c6ef8619c8",
"model_ability": "text-to-audio",
"multilingual": true
},
diff --git a/xinference/model/audio/model_spec_modelscope.json b/xinference/model/audio/model_spec_modelscope.json
index e3f46f84bc..e47f1f8e3a 100644
--- a/xinference/model/audio/model_spec_modelscope.json
+++ b/xinference/model/audio/model_spec_modelscope.json
@@ -42,7 +42,7 @@
"model_name": "ChatTTS",
"model_family": "ChatTTS",
"model_hub": "modelscope",
- "model_id": "pzc163/chatTTS",
+ "model_id": "AI-ModelScope/ChatTTS",
"model_revision": "master",
"model_ability": "text-to-audio",
"multilingual": true
diff --git a/xinference/model/audio/tests/test_chattts.py b/xinference/model/audio/tests/test_chattts.py
index 93739d409e..cadd732351 100644
--- a/xinference/model/audio/tests/test_chattts.py
+++ b/xinference/model/audio/tests/test_chattts.py
@@ -46,12 +46,14 @@ def test_chattts(setup):
response = model.speech(input_string, stream=True)
assert inspect.isgenerator(response)
- i = 0
- for chunk in response:
- i += 1
- assert type(chunk) is bytes
- assert len(chunk) > 0
- assert i > 5
+ with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as f:
+ i = 0
+ for chunk in response:
+ f.write(chunk)
+ i += 1
+ assert type(chunk) is bytes
+ assert len(chunk) > 0
+ assert i > 5
# Test openai API
import openai
diff --git a/xinference/model/embedding/model_spec.json b/xinference/model/embedding/model_spec.json
index 14d4ced519..dc8d851b85 100644
--- a/xinference/model/embedding/model_spec.json
+++ b/xinference/model/embedding/model_spec.json
@@ -233,7 +233,7 @@
},
{
"model_name": "gte-Qwen2",
- "dimensions": 3584,
+ "dimensions": 4096,
"max_tokens": 32000,
"language": ["zh", "en"],
"model_id": "Alibaba-NLP/gte-Qwen2-7B-instruct",
diff --git a/xinference/model/image/core.py b/xinference/model/image/core.py
index 8c284c9d87..581358b789 100644
--- a/xinference/model/image/core.py
+++ b/xinference/model/image/core.py
@@ -11,17 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import collections.abc
import logging
import os
+import platform
from collections import defaultdict
-from typing import Dict, List, Literal, Optional, Tuple
+from typing import Dict, List, Literal, Optional, Tuple, Union
from ...constants import XINFERENCE_CACHE_DIR
from ...types import PeftModelConfig
from ..core import CacheableModelSpec, ModelDescription
from ..utils import valid_model_revision
+from .ocr.got_ocr2 import GotOCR2Model
from .stable_diffusion.core import DiffusionModel
+from .stable_diffusion.mlx import MLXDiffusionModel
logger = logging.getLogger(__name__)
@@ -45,6 +49,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
model_hub: str = "huggingface"
model_ability: Optional[List[str]]
controlnet: Optional[List["ImageModelFamilyV1"]]
+ default_model_config: Optional[dict] = {}
default_generate_config: Optional[dict] = {}
@@ -180,6 +185,28 @@ def get_cache_status(
return valid_model_revision(meta_path, model_spec.model_revision)
+def create_ocr_model_instance(
+ subpool_addr: str,
+ devices: List[str],
+ model_uid: str,
+ model_spec: ImageModelFamilyV1,
+ model_path: Optional[str] = None,
+ **kwargs,
+) -> Tuple[GotOCR2Model, ImageModelDescription]:
+ if not model_path:
+ model_path = cache(model_spec)
+ model = GotOCR2Model(
+ model_uid,
+ model_path,
+ model_spec=model_spec,
+ **kwargs,
+ )
+ model_description = ImageModelDescription(
+ subpool_addr, devices, model_spec, model_path=model_path
+ )
+ return model, model_description
+
+
def create_image_model_instance(
subpool_addr: str,
devices: List[str],
@@ -189,8 +216,26 @@ def create_image_model_instance(
download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
model_path: Optional[str] = None,
**kwargs,
-) -> Tuple[DiffusionModel, ImageModelDescription]:
+) -> Tuple[
+ Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model], ImageModelDescription
+]:
model_spec = match_diffusion(model_name, download_hub)
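+ # OCR specs (e.g. GOT-OCR2_0) are routed to a dedicated loader instead of the
+ # diffusion pipeline below.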
+ if model_spec.model_ability and "ocr" in model_spec.model_ability:
+ return create_ocr_model_instance(
+ subpool_addr=subpool_addr,
+ devices=devices,
+ model_uid=model_uid,
+ model_name=model_name,
+ model_spec=model_spec,
+ model_path=model_path,
+ **kwargs,
+ )
+
+ # use default model config
+ model_default_config = (model_spec.default_model_config or {}).copy()
+ model_default_config.update(kwargs)
+ kwargs = model_default_config
+
controlnet = kwargs.get("controlnet")
# Handle controlnet
if controlnet is not None:
@@ -232,10 +277,20 @@ def create_image_model_instance(
lora_load_kwargs = None
lora_fuse_kwargs = None
- model = DiffusionModel(
+ if (
+ platform.system() == "Darwin"
+ and "arm" in platform.machine().lower()
+ and model_name in MLXDiffusionModel.supported_models
+ ):
+ # Mac with M series silicon chips
+ model_cls = MLXDiffusionModel
+ else:
+ model_cls = DiffusionModel # type: ignore
+
+ model = model_cls(
model_uid,
model_path,
- lora_model_paths=lora_model,
+ lora_model=lora_model,
lora_load_kwargs=lora_load_kwargs,
lora_fuse_kwargs=lora_fuse_kwargs,
model_spec=model_spec,
diff --git a/xinference/model/image/model_spec.json b/xinference/model/image/model_spec.json
index 04386dd2e5..24933cb99e 100644
--- a/xinference/model/image/model_spec.json
+++ b/xinference/model/image/model_spec.json
@@ -8,7 +8,11 @@
"text2image",
"image2image",
"inpainting"
- ]
+ ],
+ "default_model_config": {
+ "quantize": true,
+ "quantize_text_encoder": "text_encoder_2"
+ }
},
{
"model_name": "FLUX.1-dev",
@@ -19,7 +23,11 @@
"text2image",
"image2image",
"inpainting"
- ]
+ ],
+ "default_model_config": {
+ "quantize": true,
+ "quantize_text_encoder": "text_encoder_2"
+ }
},
{
"model_name": "sd3-medium",
@@ -30,7 +38,11 @@
"text2image",
"image2image",
"inpainting"
- ]
+ ],
+ "default_model_config": {
+ "quantize": true,
+ "quantize_text_encoder": "text_encoder_3"
+ }
},
{
"model_name": "sd-turbo",
@@ -178,5 +190,14 @@
"model_ability": [
"inpainting"
]
+ },
+ {
+ "model_name": "GOT-OCR2_0",
+ "model_family": "ocr",
+ "model_id": "stepfun-ai/GOT-OCR2_0",
+ "model_revision": "cf6b7386bc89a54f09785612ba74cb12de6fa17c",
+ "model_ability": [
+ "ocr"
+ ]
}
]
diff --git a/xinference/model/image/model_spec_modelscope.json b/xinference/model/image/model_spec_modelscope.json
index b39bfc543d..ad8af7a26f 100644
--- a/xinference/model/image/model_spec_modelscope.json
+++ b/xinference/model/image/model_spec_modelscope.json
@@ -9,7 +9,11 @@
"text2image",
"image2image",
"inpainting"
- ]
+ ],
+ "default_model_config": {
+ "quantize": true,
+ "quantize_text_encoder": "text_encoder_2"
+ }
},
{
"model_name": "FLUX.1-dev",
@@ -21,7 +25,11 @@
"text2image",
"image2image",
"inpainting"
- ]
+ ],
+ "default_model_config": {
+ "quantize": true,
+ "quantize_text_encoder": "text_encoder_2"
+ }
},
{
"model_name": "sd3-medium",
@@ -33,7 +41,11 @@
"text2image",
"image2image",
"inpainting"
- ]
+ ],
+ "default_model_config": {
+ "quantize": true,
+ "quantize_text_encoder": "text_encoder_3"
+ }
},
{
"model_name": "sd-turbo",
@@ -148,5 +160,15 @@
"model_revision": "62134b9d8e703b5d6f74f1534457287a8bba77ef"
}
]
+ },
+ {
+ "model_name": "GOT-OCR2_0",
+ "model_family": "ocr",
+ "model_id": "stepfun-ai/GOT-OCR2_0",
+ "model_revision": "master",
+ "model_hub": "modelscope",
+ "model_ability": [
+ "ocr"
+ ]
}
]
diff --git a/xinference/model/image/ocr/__init__.py b/xinference/model/image/ocr/__init__.py
new file mode 100644
index 0000000000..37f6558d95
--- /dev/null
+++ b/xinference/model/image/ocr/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/xinference/model/image/ocr/got_ocr2.py b/xinference/model/image/ocr/got_ocr2.py
new file mode 100644
index 0000000000..803e23ed03
--- /dev/null
+++ b/xinference/model/image/ocr/got_ocr2.py
@@ -0,0 +1,76 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Optional
+
+import PIL.Image
+
+if TYPE_CHECKING:
+ from ..core import ImageModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class GotOCR2Model:
+ def __init__(
+ self,
+ model_uid: str,
+ model_path: Optional[str] = None,
+ device: Optional[str] = None,
+ model_spec: Optional["ImageModelFamilyV1"] = None,
+ **kwargs,
+ ):
+ self._model_uid = model_uid
+ self._model_path = model_path
+ self._device = device
+ # model info when loading
+ self._model = None
+ self._tokenizer = None
+ # info
+ self._model_spec = model_spec
+ self._abilities = model_spec.model_ability or [] # type: ignore
+ self._kwargs = kwargs
+
+ @property
+ def model_ability(self):
+ return self._abilities
+
+ def load(self):
+ from transformers import AutoModel, AutoTokenizer
+
+ self._tokenizer = AutoTokenizer.from_pretrained(
+ self._model_path, trust_remote_code=True
+ )
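+ # GOT-OCR2 ships custom modeling code, hence trust_remote_code; the current
+ # implementation assumes a CUDA device is available.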
+ model = AutoModel.from_pretrained(
+ self._model_path,
+ trust_remote_code=True,
+ low_cpu_mem_usage=True,
+ device_map="cuda",
+ use_safetensors=True,
+ pad_token_id=self._tokenizer.eos_token_id,
+ )
+ self._model = model.eval().cuda()
+
+ def ocr(
+ self,
+ image: PIL.Image,
+ **kwargs,
+ ):
+ logger.info("Got OCR 2.0 kwargs: %s", kwargs)
+ if "ocr_type" not in kwargs:
+ kwargs["ocr_type"] = "ocr"
+ assert self._model is not None
+ # This chat API limits the max new tokens inside.
+ return self._model.chat(self._tokenizer, image, gradio_input=True, **kwargs)
diff --git a/xinference/model/image/scheduler/flux.py b/xinference/model/image/scheduler/flux.py
index 174acb82e3..b681e59fa7 100644
--- a/xinference/model/image/scheduler/flux.py
+++ b/xinference/model/image/scheduler/flux.py
@@ -124,7 +124,7 @@ def __init__(self):
self._running_queue: deque[Text2ImageRequest] = deque() # type: ignore
self._model = None
self._available_device = get_available_device()
- self._id_to_req: Dict[str, Text2ImageRequest] = {}
+ self._id_to_req: Dict[str, Text2ImageRequest] = {} # type: ignore
def set_model(self, model):
"""
diff --git a/xinference/model/image/stable_diffusion/core.py b/xinference/model/image/stable_diffusion/core.py
index ae9b6e4bd4..c5a9b33f86 100644
--- a/xinference/model/image/stable_diffusion/core.py
+++ b/xinference/model/image/stable_diffusion/core.py
@@ -283,9 +283,8 @@ def _load_to_device(self, model):
model.enable_sequential_cpu_offload()
elif not self._kwargs.get("device_map"):
logger.debug("Loading model to available device")
- model = move_model_to_available_device(self._model)
- # Recommended if your computer has < 64 GB of RAM
- if self._kwargs.get("attention_slicing", True):
+ model = move_model_to_available_device(model)
+ if self._kwargs.get("attention_slicing", False):
model.enable_attention_slicing()
if self._kwargs.get("vae_tiling", False):
model.enable_vae_tiling()
diff --git a/xinference/model/image/stable_diffusion/mlx.py b/xinference/model/image/stable_diffusion/mlx.py
new file mode 100644
index 0000000000..849ff62aab
--- /dev/null
+++ b/xinference/model/image/stable_diffusion/mlx.py
@@ -0,0 +1,221 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import gc
+import logging
+import re
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+from xoscar.utils import classproperty
+
+from ....types import LoRA
+from ..sdapi import SDAPIDiffusionModelMixin
+from ..utils import handle_image_result
+
+if TYPE_CHECKING:
+ from ....core.progress_tracker import Progressor
+ from ..core import ImageModelFamilyV1
+
+
+logger = logging.getLogger(__name__)
+
+
+def quantization_predicate(name: str, m) -> bool:
+ return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
+
+
+def to_latent_size(image_size: Tuple[int, int]):
+ h, w = image_size
+ h = ((h + 15) // 16) * 16
+ w = ((w + 15) // 16) * 16
+
+ if (h, w) != image_size:
+ print(
+ "Warning: The image dimensions need to be divisible by 16px. "
+ f"Changing size to {h}x{w}."
+ )
+
+ return (h // 8, w // 8)
+
+
+class MLXDiffusionModel(SDAPIDiffusionModelMixin):
+ def __init__(
+ self,
+ model_uid: str,
+ model_path: Optional[str] = None,
+ device: Optional[str] = None,
+ lora_model: Optional[List[LoRA]] = None,
+ lora_load_kwargs: Optional[Dict] = None,
+ lora_fuse_kwargs: Optional[Dict] = None,
+ model_spec: Optional["ImageModelFamilyV1"] = None,
+ **kwargs,
+ ):
+ self._model_uid = model_uid
+ self._model_path = model_path
+ self._device = device
+ # model info when loading
+ self._model = None
+ self._lora_model = lora_model
+ self._lora_load_kwargs = lora_load_kwargs or {}
+ self._lora_fuse_kwargs = lora_fuse_kwargs or {}
+ # info
+ self._model_spec = model_spec
+ self._abilities = model_spec.model_ability or [] # type: ignore
+ self._kwargs = kwargs
+
+ @property
+ def model_ability(self):
+ return self._abilities
+
+ @classproperty
+ def supported_models(self):
+ return ["FLUX.1-schnell", "FLUX.1-dev"]
+
+ def load(self):
+ try:
+ import mlx.nn as nn
+ except ImportError:
+ error_message = "Failed to import module 'mlx'"
+ installation_guide = [
+ "Please make sure 'mlx' is installed. ",
+ "You can install it by `pip install mlx`\n",
+ ]
+
+ raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+ from ....thirdparty.mlx.flux import FluxPipeline
+
+ logger.debug(
+ "Loading model from %s, kwargs: %s", self._model_path, self._kwargs
+ )
+ flux = self._model = FluxPipeline(
+ "flux-" + self._model_spec.model_name.split("-")[1],
+ model_path=self._model_path,
+ t5_padding=self._kwargs.get("t5_padding", True),
+ )
+ self._apply_lora()
+
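+ # Quantize the flow, T5 and CLIP weights by default to reduce memory use on Apple silicon.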
+ quantize = self._kwargs.get("quantize", True)
+ if quantize:
+ nn.quantize(flux.flow, class_predicate=quantization_predicate)
+ nn.quantize(flux.t5, class_predicate=quantization_predicate)
+ nn.quantize(flux.clip, class_predicate=quantization_predicate)
+
+ def _apply_lora(self):
+ if self._lora_model is not None:
+ import mlx.core as mx
+
+ for lora_model in self._lora_model:
+ weights, lora_config = mx.load(
+ lora_model.local_path, return_metadata=True
+ )
+ rank = int(lora_config.get("lora_rank", 8))
+ num_blocks = int(lora_config.get("lora_blocks", -1))
+ flux = self._model
+ flux.linear_to_lora_layers(rank, num_blocks)
+ flux.flow.load_weights(list(weights.items()), strict=False)
+ flux.fuse_lora_layers()
+ logger.info(f"Successfully loaded the LoRA for model {self._model_uid}.")
+
+ @staticmethod
+ @contextlib.contextmanager
+ def _release_after():
+ import mlx.core as mx
+
+ try:
+ yield
+ finally:
+ gc.collect()
+ mx.metal.clear_cache()
+
+ def text_to_image(
+ self,
+ prompt: str,
+ n: int = 1,
+ size: str = "1024*1024",
+ response_format: str = "url",
+ **kwargs,
+ ):
+ import mlx.core as mx
+
+ flux = self._model
+ width, height = map(int, re.split(r"[^\d]+", size))
+
+ # Make the generator
+ latent_size = to_latent_size((height, width))
+ gen_latent_kwargs = {}
+ if (num_steps := kwargs.get("num_inference_steps")) is None:
+ num_steps = 50 if "dev" in self._model_spec.model_name else 2 # type: ignore
+ gen_latent_kwargs["num_steps"] = num_steps
+ if guidance := kwargs.get("guidance_scale"):
+ gen_latent_kwargs["guidance"] = guidance
+ if seed := kwargs.get("seed"):
+ gen_latent_kwargs["seed"] = seed
+
+ with self._release_after():
+ latents = flux.generate_latents( # type: ignore
+ prompt, n_images=n, latent_size=latent_size, **gen_latent_kwargs
+ )
+
+ # First we get and eval the conditioning
+ conditioning = next(latents)
+ mx.eval(conditioning)
+ peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
+ mx.metal.reset_peak_memory()
+
+ progressor: Progressor = kwargs.pop("progressor", None)
+ # Actual denoising loop
+ for i, x_t in enumerate(latents):
+ mx.eval(x_t)
+ progressor.set_progress((i + 1) / num_steps)
+
+ peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
+ mx.metal.reset_peak_memory()
+
+ # Decode them into images
+ decoded = []
+ for i in range(n):
+ decoded.append(flux.decode(x_t[i : i + 1], latent_size)) # type: ignore
+ mx.eval(decoded[-1])
+ peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
+ peak_mem_overall = max(
+ peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
+ )
+
+ images = []
+ x = mx.concatenate(decoded, axis=0)
+ x = (x * 255).astype(mx.uint8)
+ for i in range(len(x)):
+ im = Image.fromarray(np.array(x[i]))
+ images.append(im)
+
+ logger.debug(
+ f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB"
+ )
+ logger.debug(
+ f"Peak memory used for the generation: {peak_mem_generation:.3f}GB"
+ )
+ logger.debug(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
+ logger.debug(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
+
+ return handle_image_result(response_format, images)
+
+ def image_to_image(self, **kwargs):
+ raise NotImplementedError
+
+ def inpainting(self, **kwargs):
+ raise NotImplementedError
diff --git a/xinference/model/image/tests/test_got_ocr2.py b/xinference/model/image/tests/test_got_ocr2.py
new file mode 100644
index 0000000000..385fb375f5
--- /dev/null
+++ b/xinference/model/image/tests/test_got_ocr2.py
@@ -0,0 +1,41 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import io
+
+from diffusers.utils import load_image
+
+
+def test_got_ocr2(setup):
+ endpoint, _ = setup
+ from ....client import Client
+
+ client = Client(endpoint)
+
+ model_uid = client.launch_model(
+ model_uid="ocr_test",
+ model_name="GOT-OCR2_0",
+ model_type="image",
+ )
+ model = client.get_model(model_uid)
+
+ url = "https://huggingface.co/stepfun-ai/GOT-OCR2_0/resolve/main/assets/train_sample.jpg"
+ image = load_image(url)
+ bio = io.BytesIO()
+ image.save(bio, format="JPEG")
+ r = model.ocr(
+ image=bio.getvalue(),
+ ocr_type="ocr",
+ )
+ assert "Jesuits Estate" in r
diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json
index ac09f0ea5b..abdcfd4f62 100644
--- a/xinference/model/llm/llm_family.json
+++ b/xinference/model/llm/llm_family.json
@@ -8176,6 +8176,15 @@
],
"model_id": "Qwen/Qwen2.5-Coder-7B-Instruct"
},
+ {
+ "model_format": "gptq",
+ "model_size_in_billions": "7",
+ "quantizations": [
+ "Int4",
+ "Int8"
+ ],
+ "model_id": "Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-{quantization}"
+ },
{
"model_format": "ggufv2",
"model_size_in_billions": "1_5",
diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json
index acf6e09fd3..7a91e561e6 100644
--- a/xinference/model/llm/llm_family_modelscope.json
+++ b/xinference/model/llm/llm_family_modelscope.json
@@ -5880,6 +5880,17 @@
"model_revision": "master",
"model_hub": "modelscope"
},
+ {
+ "model_format": "gptq",
+ "model_size_in_billions": 7,
+ "quantizations": [
+ "Int4",
+ "Int8"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-{quantization}",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
{
"model_format": "ggufv2",
"model_size_in_billions": "1_5",
diff --git a/xinference/thirdparty/mlx/__init__.py b/xinference/thirdparty/mlx/__init__.py
new file mode 100644
index 0000000000..37f6558d95
--- /dev/null
+++ b/xinference/thirdparty/mlx/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/xinference/thirdparty/mlx/flux/__init__.py b/xinference/thirdparty/mlx/flux/__init__.py
new file mode 100644
index 0000000000..b1122d75d6
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/__init__.py
@@ -0,0 +1,15 @@
+# Copyright © 2024 Apple Inc.
+
+from .datasets import Dataset, load_dataset
+from .flux import FluxPipeline
+from .lora import LoRALinear
+from .sampler import FluxSampler
+from .trainer import Trainer
+from .utils import (
+ load_ae,
+ load_clip,
+ load_clip_tokenizer,
+ load_flow_model,
+ load_t5,
+ load_t5_tokenizer,
+)
diff --git a/xinference/thirdparty/mlx/flux/autoencoder.py b/xinference/thirdparty/mlx/flux/autoencoder.py
new file mode 100644
index 0000000000..6332bb570b
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/autoencoder.py
@@ -0,0 +1,357 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import List
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.nn.layers.upsample import upsample_nearest
+
+
+@dataclass
+class AutoEncoderParams:
+ resolution: int
+ in_channels: int
+ ch: int
+ out_ch: int
+ ch_mult: List[int]
+ num_res_blocks: int
+ z_channels: int
+ scale_factor: float
+ shift_factor: float
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = nn.GroupNorm(
+ num_groups=32,
+ dims=in_channels,
+ eps=1e-6,
+ affine=True,
+ pytorch_compatible=True,
+ )
+ self.q = nn.Linear(in_channels, in_channels)
+ self.k = nn.Linear(in_channels, in_channels)
+ self.v = nn.Linear(in_channels, in_channels)
+ self.proj_out = nn.Linear(in_channels, in_channels)
+
+ def __call__(self, x: mx.array) -> mx.array:
+ B, H, W, C = x.shape
+
+ y = x.reshape(B, 1, -1, C)
+ y = self.norm(y)
+ q = self.q(y)
+ k = self.k(y)
+ v = self.v(y)
+ y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
+ y = self.proj_out(y)
+
+ return x + y.reshape(B, H, W, C)
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+
+ self.norm1 = nn.GroupNorm(
+ num_groups=32,
+ dims=in_channels,
+ eps=1e-6,
+ affine=True,
+ pytorch_compatible=True,
+ )
+ self.conv1 = nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ self.norm2 = nn.GroupNorm(
+ num_groups=32,
+ dims=out_channels,
+ eps=1e-6,
+ affine=True,
+ pytorch_compatible=True,
+ )
+ self.conv2 = nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if self.in_channels != self.out_channels:
+ self.nin_shortcut = nn.Linear(in_channels, out_channels)
+
+ def __call__(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nn.silu(h)
+ h = self.conv1(h)
+
+ h = self.norm2(h)
+ h = nn.silu(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
+ )
+
+ def __call__(self, x: mx.array):
+ x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def __call__(self, x: mx.array):
+ x = upsample_nearest(x, (2, 2))
+ x = self.conv(x)
+ return x
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ resolution: int,
+ in_channels: int,
+ ch: int,
+ ch_mult: list[int],
+ num_res_blocks: int,
+ z_channels: int,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ # downsampling
+ self.conv_in = nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = []
+ block_in = self.ch
+ for i_level in range(self.num_resolutions):
+ block = []
+ attn = [] # TODO: Remove the attn, nobody appends anything to it
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+ block_in = block_out
+ down = {}
+ down["block"] = block
+ down["attn"] = attn
+ if i_level != self.num_resolutions - 1:
+ down["downsample"] = Downsample(block_in)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = {}
+ self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+ self.mid["attn_1"] = AttnBlock(block_in)
+ self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+ # end
+ self.norm_out = nn.GroupNorm(
+ num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+ )
+ self.conv_out = nn.Conv2d(
+ block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def __call__(self, x: mx.array):
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level]["block"][i_block](hs[-1])
+
+ # TODO: Remove the attn
+ if len(self.down[i_level]["attn"]) > 0:
+ h = self.down[i_level]["attn"][i_block](h)
+
+ hs.append(h)
+
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level]["downsample"](hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid["block_1"](h)
+ h = self.mid["attn_1"](h)
+ h = self.mid["block_2"](h)
+
+ # end
+ h = self.norm_out(h)
+ h = nn.silu(h)
+ h = self.conv_out(h)
+
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ ch: int,
+ out_ch: int,
+ ch_mult: list[int],
+ num_res_blocks: int,
+ in_channels: int,
+ resolution: int,
+ z_channels: int,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.ffactor = 2 ** (self.num_resolutions - 1)
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+
+ # z to block_in
+ self.conv_in = nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
+ )
+
+ # middle
+ self.mid = {}
+ self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+ self.mid["attn_1"] = AttnBlock(block_in)
+ self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+ # upsampling
+ self.up = []
+ for i_level in reversed(range(self.num_resolutions)):
+ block = []
+ attn = [] # TODO: Remove the attn, nobody appends anything to it
+
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks + 1):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+ block_in = block_out
+ up = {}
+ up["block"] = block
+ up["attn"] = attn
+ if i_level != 0:
+ up["upsample"] = Upsample(block_in)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = nn.GroupNorm(
+ num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+ )
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def __call__(self, z: mx.array):
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid["block_1"](h)
+ h = self.mid["attn_1"](h)
+ h = self.mid["block_2"](h)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level]["block"][i_block](h)
+
+ # TODO: Remove the attn
+ if len(self.up[i_level]["attn"]) > 0:
+ h = self.up[i_level]["attn"][i_block](h)
+
+ if i_level != 0:
+ h = self.up[i_level]["upsample"](h)
+
+ # end
+ h = self.norm_out(h)
+ h = nn.silu(h)
+ h = self.conv_out(h)
+
+ return h
+
+
+class DiagonalGaussian(nn.Module):
+ def __call__(self, z: mx.array):
+ mean, logvar = mx.split(z, 2, axis=-1)
+ if self.training:
+ std = mx.exp(0.5 * logvar)
+ eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
+ return mean + std * eps
+ else:
+ return mean
+
+
+class AutoEncoder(nn.Module):
+ def __init__(self, params: AutoEncoderParams):
+ super().__init__()
+ self.encoder = Encoder(
+ resolution=params.resolution,
+ in_channels=params.in_channels,
+ ch=params.ch,
+ ch_mult=params.ch_mult,
+ num_res_blocks=params.num_res_blocks,
+ z_channels=params.z_channels,
+ )
+ self.decoder = Decoder(
+ resolution=params.resolution,
+ in_channels=params.in_channels,
+ ch=params.ch,
+ out_ch=params.out_ch,
+ ch_mult=params.ch_mult,
+ num_res_blocks=params.num_res_blocks,
+ z_channels=params.z_channels,
+ )
+ self.reg = DiagonalGaussian()
+
+ self.scale_factor = params.scale_factor
+ self.shift_factor = params.shift_factor
+
+ def sanitize(self, weights):
+ new_weights = {}
+ for k, w in weights.items():
+ if w.ndim == 4:
+ w = w.transpose(0, 2, 3, 1)
+ w = w.reshape(-1).reshape(w.shape)
+ if w.shape[1:3] == (1, 1):
+ w = w.squeeze((1, 2))
+ new_weights[k] = w
+ return new_weights
+
+ def encode(self, x: mx.array):
+ z = self.reg(self.encoder(x))
+ z = self.scale_factor * (z - self.shift_factor)
+ return z
+
+ def decode(self, z: mx.array):
+ z = z / self.scale_factor + self.shift_factor
+ return self.decoder(z)
+
+ def __call__(self, x: mx.array):
+ return self.decode(self.encode(x))
diff --git a/xinference/thirdparty/mlx/flux/clip.py b/xinference/thirdparty/mlx/flux/clip.py
new file mode 100644
index 0000000000..d5a30dbf34
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/clip.py
@@ -0,0 +1,154 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
+
+
+@dataclass
+class CLIPTextModelConfig:
+ num_layers: int = 23
+ model_dims: int = 1024
+ num_heads: int = 16
+ max_length: int = 77
+ vocab_size: int = 49408
+ hidden_act: str = "quick_gelu"
+
+ @classmethod
+ def from_dict(cls, config):
+ return cls(
+ num_layers=config["num_hidden_layers"],
+ model_dims=config["hidden_size"],
+ num_heads=config["num_attention_heads"],
+ max_length=config["max_position_embeddings"],
+ vocab_size=config["vocab_size"],
+ hidden_act=config["hidden_act"],
+ )
+
+
+@dataclass
+class CLIPOutput:
+ # The last_hidden_state indexed at the EOS token and possibly projected if
+ # the model has a projection layer
+ pooled_output: Optional[mx.array] = None
+
+ # The full sequence output of the transformer after the final layernorm
+ last_hidden_state: Optional[mx.array] = None
+
+ # A list of hidden states corresponding to the outputs of the transformer layers
+ hidden_states: Optional[List[mx.array]] = None
+
+
+class CLIPEncoderLayer(nn.Module):
+ """The transformer encoder layer from CLIP."""
+
+ def __init__(self, model_dims: int, num_heads: int, activation: str):
+ super().__init__()
+
+ self.layer_norm1 = nn.LayerNorm(model_dims)
+ self.layer_norm2 = nn.LayerNorm(model_dims)
+
+ self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
+
+ self.linear1 = nn.Linear(model_dims, 4 * model_dims)
+ self.linear2 = nn.Linear(4 * model_dims, model_dims)
+
+ self.act = _ACTIVATIONS[activation]
+
+ def __call__(self, x, attn_mask=None):
+ y = self.layer_norm1(x)
+ y = self.attention(y, y, y, attn_mask)
+ x = y + x
+
+ y = self.layer_norm2(x)
+ y = self.linear1(y)
+ y = self.act(y)
+ y = self.linear2(y)
+ x = y + x
+
+ return x
+
+
+class CLIPTextModel(nn.Module):
+ """Implements the text encoder transformer from CLIP."""
+
+ def __init__(self, config: CLIPTextModelConfig):
+ super().__init__()
+
+ self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
+ self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
+ self.layers = [
+ CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
+ for i in range(config.num_layers)
+ ]
+ self.final_layer_norm = nn.LayerNorm(config.model_dims)
+
+ def _get_mask(self, N, dtype):
+ indices = mx.arange(N)
+ mask = indices[:, None] < indices[None]
+ mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
+ return mask
+
+ def sanitize(self, weights):
+ new_weights = {}
+ for key, w in weights.items():
+ # Remove prefixes
+ if key.startswith("text_model."):
+ key = key[11:]
+ if key.startswith("embeddings."):
+ key = key[11:]
+ if key.startswith("encoder."):
+ key = key[8:]
+
+ # Map attention layers
+ if "self_attn." in key:
+ key = key.replace("self_attn.", "attention.")
+ if "q_proj." in key:
+ key = key.replace("q_proj.", "query_proj.")
+ if "k_proj." in key:
+ key = key.replace("k_proj.", "key_proj.")
+ if "v_proj." in key:
+ key = key.replace("v_proj.", "value_proj.")
+
+ # Map ffn layers
+ if "mlp.fc1" in key:
+ key = key.replace("mlp.fc1", "linear1")
+ if "mlp.fc2" in key:
+ key = key.replace("mlp.fc2", "linear2")
+
+ new_weights[key] = w
+
+ return new_weights
+
+ def __call__(self, x):
+ # Extract some shapes
+ B, N = x.shape
+ eos_tokens = x.argmax(-1)
+
+ # Compute the embeddings
+ x = self.token_embedding(x)
+ x = x + self.position_embedding.weight[:N]
+
+ # Compute the features from the transformer
+ mask = self._get_mask(N, x.dtype)
+ hidden_states = []
+ for l in self.layers:
+ x = l(x, mask)
+ hidden_states.append(x)
+
+ # Apply the final layernorm and return
+ x = self.final_layer_norm(x)
+ last_hidden_state = x
+
+ # Select the EOS token
+ pooled_output = x[mx.arange(len(x)), eos_tokens]
+
+ return CLIPOutput(
+ pooled_output=pooled_output,
+ last_hidden_state=last_hidden_state,
+ hidden_states=hidden_states,
+ )
diff --git a/xinference/thirdparty/mlx/flux/datasets.py b/xinference/thirdparty/mlx/flux/datasets.py
new file mode 100644
index 0000000000..d31a09f179
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/datasets.py
@@ -0,0 +1,75 @@
+import json
+from pathlib import Path
+
+from PIL import Image
+
+
+class Dataset:
+ def __getitem__(self, index: int):
+ raise NotImplementedError()
+
+ def __len__(self):
+ raise NotImplementedError()
+
+
+class LocalDataset(Dataset):
+ prompt_key = "prompt"
+
+ def __init__(self, dataset: str, data_file):
+ self.dataset_base = Path(dataset)
+ with open(data_file, "r") as fid:
+ self._data = [json.loads(l) for l in fid]
+
+ def __len__(self):
+ return len(self._data)
+
+ def __getitem__(self, index: int):
+ item = self._data[index]
+ image = Image.open(self.dataset_base / item["image"])
+ return image, item[self.prompt_key]
+
+
+class LegacyDataset(LocalDataset):
+ prompt_key = "text"
+
+ def __init__(self, dataset: str):
+ self.dataset_base = Path(dataset)
+ with open(self.dataset_base / "index.json") as f:
+ self._data = json.load(f)["data"]
+
+
+class HuggingFaceDataset(Dataset):
+
+ def __init__(self, dataset: str):
+ from datasets import load_dataset as hf_load_dataset
+
+ self._df = hf_load_dataset(dataset)["train"]
+
+ def __len__(self):
+ return len(self._df)
+
+ def __getitem__(self, index: int):
+ item = self._df[index]
+ return item["image"], item["prompt"]
+
+
+def load_dataset(dataset: str):
+ dataset_base = Path(dataset)
+ data_file = dataset_base / "train.jsonl"
+ legacy_file = dataset_base / "index.json"
+
+ if data_file.exists():
+ print(f"Load the local dataset {data_file} .", flush=True)
+ dataset = LocalDataset(dataset, data_file)
+ elif legacy_file.exists():
+ print(f"Load the local dataset {legacy_file} .")
+ print()
+ print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
+ print(" See the README for details.")
+ print(flush=True)
+ dataset = LegacyDataset(dataset)
+ else:
+ print(f"Load the Hugging Face dataset {dataset} .", flush=True)
+ dataset = HuggingFaceDataset(dataset)
+
+ return dataset
diff --git a/xinference/thirdparty/mlx/flux/flux.py b/xinference/thirdparty/mlx/flux/flux.py
new file mode 100644
index 0000000000..425cb4b9ea
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/flux.py
@@ -0,0 +1,247 @@
+# Copyright © 2024 Apple Inc.
+
+from typing import Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
+from tqdm import tqdm
+
+from .lora import LoRALinear
+from .sampler import FluxSampler
+from .utils import (
+ load_ae,
+ load_clip,
+ load_clip_tokenizer,
+ load_flow_model,
+ load_t5,
+ load_t5_tokenizer,
+)
+
+
+class FluxPipeline:
+ def __init__(self, name: str, model_path: str, t5_padding: bool = True):
+ self.dtype = mx.bfloat16
+ self.name = name
+ self.t5_padding = t5_padding
+
+ self.model_path = model_path
+ self.ae = load_ae(name, model_path)
+ self.flow = load_flow_model(name, model_path)
+ self.clip = load_clip(name, model_path)
+ self.clip_tokenizer = load_clip_tokenizer(name, model_path)
+ self.t5 = load_t5(name, model_path)
+ self.t5_tokenizer = load_t5_tokenizer(name, model_path)
+ self.sampler = FluxSampler(name)
+
+ def ensure_models_are_loaded(self):
+ mx.eval(
+ self.ae.parameters(),
+ self.flow.parameters(),
+ self.clip.parameters(),
+ self.t5.parameters(),
+ )
+
+ def reload_text_encoders(self):
+ self.t5 = load_t5(self.name, self.model_path)
+ self.clip = load_clip(self.name, self.model_path)
+
+ def tokenize(self, text):
+ t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
+ clip_tokens = self.clip_tokenizer.encode(text)
+ return t5_tokens, clip_tokens
+
+ def _prepare_latent_images(self, x):
+ b, h, w, c = x.shape
+
+ # Pack the latent image to 2x2 patches
+ x = x.reshape(b, h // 2, 2, w // 2, 2, c)
+ x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
+
+ # Create position ids used to positionally encode each patch. Due to
+ # the way RoPE works, this results in an interesting positional
+ # encoding where different parts of the feature hold different positional
+ # information: the first part holds information independent of the
+ # spatial position (hence the 0s), the second part holds vertical spatial
+ # information and the last one horizontal.
+ i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
+ j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
+ x_ids = mx.stack([i, j, k], axis=-1)
+ x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
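+ # Illustrative example (not part of the upstream code): with a 4x4 latent
+ # (h = w = 4) there are four 2x2 patches and x_ids[0] is
+ # [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]],
+ # i.e. a zero axis plus the (row, column) index of every patch.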
+
+ return x, x_ids
+
+ def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
+ # Prepare the text features
+ txt = self.t5(t5_tokens)
+ if len(txt) == 1 and n_images > 1:
+ txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
+ txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
+
+ # Prepare the clip text features
+ vec = self.clip(clip_tokens).pooled_output
+ if len(vec) == 1 and n_images > 1:
+ vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
+
+ return txt, txt_ids, vec
+
+ def _denoising_loop(
+ self,
+ x_t,
+ x_ids,
+ txt,
+ txt_ids,
+ vec,
+ num_steps: int = 35,
+ guidance: float = 4.0,
+ start: float = 1,
+ stop: float = 0,
+ ):
+ B = len(x_t)
+
+ def scalar(x):
+ return mx.full((B,), x, dtype=self.dtype)
+
+ guidance = scalar(guidance)
+ timesteps = self.sampler.timesteps(
+ num_steps,
+ x_t.shape[1],
+ start=start,
+ stop=stop,
+ )
+ for i in range(num_steps):
+ t = timesteps[i]
+ t_prev = timesteps[i + 1]
+
+ pred = self.flow(
+ img=x_t,
+ img_ids=x_ids,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=scalar(t),
+ guidance=guidance,
+ )
+ x_t = self.sampler.step(pred, x_t, t, t_prev)
+
+ yield x_t
+
+ def generate_latents(
+ self,
+ text: str,
+ n_images: int = 1,
+ num_steps: int = 35,
+ guidance: float = 4.0,
+ latent_size: Tuple[int, int] = (64, 64),
+ seed=None,
+ ):
+ # Set the PRNG state
+ if seed is not None:
+ mx.random.seed(seed)
+
+ # Create the latent variables
+ x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
+ x_T, x_ids = self._prepare_latent_images(x_T)
+
+ # Get the conditioning
+ t5_tokens, clip_tokens = self.tokenize(text)
+ txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
+
+ # Yield the conditioning for controlled evaluation by the caller
+ yield (x_T, x_ids, txt, txt_ids, vec)
+
+ # Yield the latent sequences from the denoising loop
+ yield from self._denoising_loop(
+ x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
+ )
+
+ def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
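+ # Unpack the 2x2 patches back into a latent image, run the VAE decoder
+ # and map the output from [-1, 1] to [0, 1].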
+ h, w = latent_size
+ x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
+ x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
+ x = self.ae.decode(x)
+ return mx.clip(x + 1, 0, 2) * 0.5
+
+ def generate_images(
+ self,
+ text: str,
+ n_images: int = 1,
+ num_steps: int = 35,
+ guidance: float = 4.0,
+ latent_size: Tuple[int, int] = (64, 64),
+ seed=None,
+ reload_text_encoders: bool = True,
+ progress: bool = True,
+ ):
+ latents = self.generate_latents(
+ text, n_images, num_steps, guidance, latent_size, seed
+ )
+ mx.eval(next(latents))
+
+ if reload_text_encoders:
+ self.reload_text_encoders()
+
+ for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
+ mx.eval(x_t)
+
+ images = []
+ for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
+ images.append(self.decode(x_t[i : i + 1]))
+ mx.eval(images[-1])
+ images = mx.concatenate(images, axis=0)
+ mx.eval(images)
+
+ return images
+
+ def training_loss(
+ self,
+ x_0: mx.array,
+ t5_features: mx.array,
+ clip_features: mx.array,
+ guidance: mx.array,
+ ):
+ # Get the text conditioning
+ txt = t5_features
+ txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
+ vec = clip_features
+
+ # Prepare the latent input
+ x_0, x_ids = self._prepare_latent_images(x_0)
+
+ # Forward process
+ t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
+ eps = mx.random.normal(x_0.shape, dtype=self.dtype)
+ x_t = self.sampler.add_noise(x_0, t, noise=eps)
+ x_t = mx.stop_gradient(x_t)
+
+ # Do the denoising
+ pred = self.flow(
+ img=x_t,
+ img_ids=x_ids,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t,
+ guidance=guidance,
+ )
+
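+ # Flow-matching objective: x_t = (1 - t) * x_0 + t * eps, so the target
+ # velocity is eps - x_0 and the loss below is ||pred - (eps - x_0)||^2,
+ # written as (pred + x_0 - eps)^2.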
+ return (pred + x_0 - eps).square().mean()
+
+ def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
+ """Swap the linear layers in the transformer blocks with LoRA layers."""
+ all_blocks = self.flow.double_blocks + self.flow.single_blocks
+ all_blocks.reverse()
+ num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
+ for i, block in zip(range(num_blocks), all_blocks):
+ loras = []
+ for name, module in block.named_modules():
+ if isinstance(module, nn.Linear):
+ loras.append((name, LoRALinear.from_base(module, r=rank)))
+ block.update_modules(tree_unflatten(loras))
+
+ def fuse_lora_layers(self):
+ fused_layers = []
+ for name, module in self.flow.named_modules():
+ if isinstance(module, LoRALinear):
+ fused_layers.append((name, module.fuse()))
+ self.flow.update_modules(tree_unflatten(fused_layers))
diff --git a/xinference/thirdparty/mlx/flux/layers.py b/xinference/thirdparty/mlx/flux/layers.py
new file mode 100644
index 0000000000..12397904e8
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/layers.py
@@ -0,0 +1,302 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from dataclasses import dataclass
+from functools import partial
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+def _rope(pos: mx.array, dim: int, theta: float):
+ scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
+ omega = 1.0 / (theta**scale)
+ x = pos[..., None] * omega
+ cosx = mx.cos(x)
+ sinx = mx.sin(x)
+ pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
+ pe = pe.reshape(*pe.shape[:-1], 2, 2)
+
+ return pe
+
+
+@partial(mx.compile, shapeless=True)
+def _ab_plus_cd(a, b, c, d):
+ return a * b + c * d
+
+
+def _apply_rope(x, pe):
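+ # Treat the last dimension of x as (real, imaginary) pairs and rotate each
+ # pair by the precomputed RoPE rotation stored in pe.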
+ s = x.shape
+ x = x.reshape(*s[:-1], -1, 1, 2)
+ x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
+ return x.reshape(s)
+
+
+def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
+ B, H, L, D = q.shape
+
+ q = _apply_rope(q, pe)
+ k = _apply_rope(k, pe)
+ x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))
+
+ return x.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+
+def timestep_embedding(
+ t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
+):
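+ # Standard sinusoidal timestep embedding: scale t by `time_factor`,
+ # multiply by geometrically spaced frequencies and concatenate cos/sin
+ # to obtain a `dim`-dimensional embedding.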
+ half = dim // 2
+ freqs = mx.arange(0, half, dtype=mx.float32) / half
+ freqs = freqs * (-math.log(max_period))
+ freqs = mx.exp(freqs)
+
+ x = (time_factor * t)[:, None] * freqs[None]
+ x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
+
+ return x.astype(t.dtype)
+
+
+class EmbedND(nn.Module):
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
+ super().__init__()
+
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def __call__(self, ids: mx.array):
+ n_axes = ids.shape[-1]
+ pe = mx.concatenate(
+ [_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ axis=-3,
+ )
+
+ return pe[:, None]
+
+
+class MLPEmbedder(nn.Module):
+ def __init__(self, in_dim: int, hidden_dim: int):
+ super().__init__()
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ def __call__(self, x: mx.array) -> mx.array:
+ return self.out_layer(nn.silu(self.in_layer(x)))
+
+
+class QKNorm(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.query_norm = nn.RMSNorm(dim)
+ self.key_norm = nn.RMSNorm(dim)
+
+ def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
+ return self.query_norm(q), self.key_norm(k)
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.norm = QKNorm(head_dim)
+ self.proj = nn.Linear(dim, dim)
+
+ def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
+ H = self.num_heads
+ B, L, _ = x.shape
+ qkv = self.qkv(x)
+ q, k, v = mx.split(qkv, 3, axis=-1)
+ q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ q, k = self.norm(q, k)
+ x = _attention(q, k, v, pe)
+ x = self.proj(x)
+ return x
+
+
+@dataclass
+class ModulationOut:
+ shift: mx.array
+ scale: mx.array
+ gate: mx.array
+
+
+class Modulation(nn.Module):
+ def __init__(self, dim: int, double: bool):
+ super().__init__()
+ self.is_double = double
+ self.multiplier = 6 if double else 3
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+
+ def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
+ x = self.lin(nn.silu(x))
+ xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
+
+ mod1 = ModulationOut(*xs[:3])
+ mod2 = ModulationOut(*xs[3:]) if self.is_double else None
+
+ return mod1, mod2
+
+
+class DoubleStreamBlock(nn.Module):
+ def __init__(
+ self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
+ ):
+ super().__init__()
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.img_mod = Modulation(hidden_size, double=True)
+ self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+ self.img_attn = SelfAttention(
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+ )
+
+ self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+ self.img_mlp = nn.Sequential(
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+ nn.GELU(approx="tanh"),
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+ )
+
+ self.txt_mod = Modulation(hidden_size, double=True)
+ self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+ self.txt_attn = SelfAttention(
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+ )
+
+ self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+ self.txt_mlp = nn.Sequential(
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+ nn.GELU(approx="tanh"),
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+ )
+
+ def __call__(
+ self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
+ ) -> Tuple[mx.array, mx.array]:
+ B, L, _ = img.shape
+ _, S, _ = txt.shape
+ H = self.num_heads
+
+ img_mod1, img_mod2 = self.img_mod(vec)
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
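+ # adaLN-style modulation: each stream gets two (shift, scale, gate)
+ # triples from the conditioning vector `vec`, one for the attention
+ # branch and one for the MLP branch.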
+
+ # prepare image for attention
+ img_modulated = self.img_norm1(img)
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+ img_qkv = self.img_attn.qkv(img_modulated)
+ img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
+ img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ img_q, img_k = self.img_attn.norm(img_q, img_k)
+
+ # prepare txt for attention
+ txt_modulated = self.txt_norm1(txt)
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
+ txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
+ txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+ txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+ txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
+
+ # run actual attention
+ q = mx.concatenate([txt_q, img_q], axis=2)
+ k = mx.concatenate([txt_k, img_k], axis=2)
+ v = mx.concatenate([txt_v, img_v], axis=2)
+
+ attn = _attention(q, k, v, pe)
+ txt_attn, img_attn = mx.split(attn, [S], axis=1)
+
+ # calculate the img blocks
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+ img = img + img_mod2.gate * self.img_mlp(
+ (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
+ )
+
+ # calculate the txt blocks
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+ txt = txt + txt_mod2.gate * self.txt_mlp(
+ (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
+ )
+
+ return img, txt
+
+
+class SingleStreamBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qk_scale: Optional[float] = None,
+ ):
+ super().__init__()
+ self.hidden_dim = hidden_size
+ self.num_heads = num_heads
+ head_dim = hidden_size // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ # qkv and mlp_in
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+ # proj and mlp_out
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+ self.norm = QKNorm(head_dim)
+
+ self.hidden_size = hidden_size
+ self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+
+ self.mlp_act = nn.GELU(approx="tanh")
+ self.modulation = Modulation(hidden_size, double=False)
+
+ def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
+ B, L, _ = x.shape
+ H = self.num_heads
+
+ mod, _ = self.modulation(vec)
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+
+ q, k, v, mlp = mx.split(
+ self.linear1(x_mod),
+ [self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
+ axis=-1,
+ )
+ q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ q, k = self.norm(q, k)
+
+ # compute attention
+ y = _attention(q, k, v, pe)
+
+ # compute activation in mlp stream, cat again and run second linear layer
+ y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
+ return x + mod.gate * y
+
+
+class LastLayer(nn.Module):
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+ self.linear = nn.Linear(
+ hidden_size, patch_size * patch_size * out_channels, bias=True
+ )
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+ )
+
+ def __call__(self, x: mx.array, vec: mx.array):
+ shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+ x = self.linear(x)
+ return x
diff --git a/xinference/thirdparty/mlx/flux/lora.py b/xinference/thirdparty/mlx/flux/lora.py
new file mode 100644
index 0000000000..b0c8ae5605
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/lora.py
@@ -0,0 +1,76 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class LoRALinear(nn.Module):
+ @staticmethod
+ def from_base(
+ linear: nn.Linear,
+ r: int = 8,
+ dropout: float = 0.0,
+ scale: float = 1.0,
+ ):
+ output_dims, input_dims = linear.weight.shape
+ lora_lin = LoRALinear(
+ input_dims=input_dims,
+ output_dims=output_dims,
+ r=r,
+ dropout=dropout,
+ scale=scale,
+ )
+ lora_lin.linear = linear
+ return lora_lin
+
+ def fuse(self):
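+ # Fold the low-rank update into the base weight so that the fused layer
+ # computes x @ (W + scale * lora_b.T @ lora_a.T).T, matching
+ # y + scale * ((x @ lora_a) @ lora_b) from __call__.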
+ linear = self.linear
+ bias = "bias" in linear
+ weight = linear.weight
+ dtype = weight.dtype
+
+ output_dims, input_dims = weight.shape
+ fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+ lora_b = self.scale * self.lora_b.T
+ lora_a = self.lora_a.T
+ fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype)
+ if bias:
+ fused_linear.bias = linear.bias
+
+ return fused_linear
+
+ def __init__(
+ self,
+ input_dims: int,
+ output_dims: int,
+ r: int = 8,
+ dropout: float = 0.0,
+ scale: float = 1.0,
+ bias: bool = False,
+ ):
+ super().__init__()
+
+ # Regular linear layer weights
+ self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+ self.dropout = nn.Dropout(p=dropout)
+
+ # Scale for low-rank update
+ self.scale = scale
+
+ # Low rank lora weights
+ scale = 1 / math.sqrt(input_dims)
+ self.lora_a = mx.random.uniform(
+ low=-scale,
+ high=scale,
+ shape=(input_dims, r),
+ )
+ self.lora_b = mx.zeros(shape=(r, output_dims))
+
+ def __call__(self, x):
+ y = self.linear(x)
+ z = (self.dropout(x) @ self.lora_a) @ self.lora_b
+ return y + (self.scale * z).astype(x.dtype)
diff --git a/xinference/thirdparty/mlx/flux/model.py b/xinference/thirdparty/mlx/flux/model.py
new file mode 100644
index 0000000000..18ea70b08a
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/model.py
@@ -0,0 +1,134 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .layers import (
+ DoubleStreamBlock,
+ EmbedND,
+ LastLayer,
+ MLPEmbedder,
+ SingleStreamBlock,
+ timestep_embedding,
+)
+
+
+@dataclass
+class FluxParams:
+ in_channels: int
+ vec_in_dim: int
+ context_in_dim: int
+ hidden_size: int
+ mlp_ratio: float
+ num_heads: int
+ depth: int
+ depth_single_blocks: int
+ axes_dim: list[int]
+ theta: int
+ qkv_bias: bool
+ guidance_embed: bool
+
+
+class Flux(nn.Module):
+ def __init__(self, params: FluxParams):
+ super().__init__()
+
+ self.params = params
+ self.in_channels = params.in_channels
+ self.out_channels = self.in_channels
+ if params.hidden_size % params.num_heads != 0:
+ raise ValueError(
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+ )
+ pe_dim = params.hidden_size // params.num_heads
+ if sum(params.axes_dim) != pe_dim:
+ raise ValueError(
+ f"Got {params.axes_dim} but expected positional dim {pe_dim}"
+ )
+ self.hidden_size = params.hidden_size
+ self.num_heads = params.num_heads
+ self.pe_embedder = EmbedND(
+ dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
+ )
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+ self.guidance_in = (
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ if params.guidance_embed
+ else nn.Identity()
+ )
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+ self.double_blocks = [
+ DoubleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ qkv_bias=params.qkv_bias,
+ )
+ for _ in range(params.depth)
+ ]
+
+ self.single_blocks = [
+ SingleStreamBlock(
+ self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
+ )
+ for _ in range(params.depth_single_blocks)
+ ]
+
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+
+ def sanitize(self, weights):
+ new_weights = {}
+ for k, w in weights.items():
+ if k.endswith(".scale"):
+ k = k[:-6] + ".weight"
+ for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
+ if f".{seq}." in k:
+ k = k.replace(f".{seq}.", f".{seq}.layers.")
+ break
+ new_weights[k] = w
+ return new_weights
+
+ def __call__(
+ self,
+ img: mx.array,
+ img_ids: mx.array,
+ txt: mx.array,
+ txt_ids: mx.array,
+ timesteps: mx.array,
+ y: mx.array,
+ guidance: Optional[mx.array] = None,
+ ) -> mx.array:
+ if img.ndim != 3 or txt.ndim != 3:
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+ img = self.img_in(img)
+ vec = self.time_in(timestep_embedding(timesteps, 256))
+ if self.params.guidance_embed:
+ if guidance is None:
+ raise ValueError(
+ "Didn't get guidance strength for guidance distilled model."
+ )
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+ vec = vec + self.vector_in(y)
+ txt = self.txt_in(txt)
+
+ ids = mx.concatenate([txt_ids, img_ids], axis=1)
+ pe = self.pe_embedder(ids).astype(img.dtype)
+
+ for block in self.double_blocks:
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
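+ # For the single-stream blocks, text and image tokens are processed as one
+ # concatenated sequence; the text tokens are dropped again afterwards.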
+ img = mx.concatenate([txt, img], axis=1)
+ for block in self.single_blocks:
+ img = block(img, vec=vec, pe=pe)
+ img = img[:, txt.shape[1] :, ...]
+
+ img = self.final_layer(img, vec)
+
+ return img
diff --git a/xinference/thirdparty/mlx/flux/sampler.py b/xinference/thirdparty/mlx/flux/sampler.py
new file mode 100644
index 0000000000..3bff1ca275
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/sampler.py
@@ -0,0 +1,56 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from functools import lru_cache
+
+import mlx.core as mx
+
+
+class FluxSampler:
+ def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5):
+ self._base_shift = base_shift
+ self._max_shift = max_shift
+ self._schnell = "schnell" in name
+
+ def _time_shift(self, x, t):
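+ # Resolution-dependent schedule shift: interpolate the shift between
+ # base_shift (256 image tokens) and max_shift (4096 tokens), then warp
+ # t through exp(mu) / (exp(mu) + 1/t - 1).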
+ x1, x2 = 256, 4096
+ t1, t2 = self._base_shift, self._max_shift
+ exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1)
+ t = exp_mu / (exp_mu + (1 / t - 1))
+ return t
+
+ @lru_cache
+ def timesteps(
+ self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
+ ):
+ t = mx.linspace(start, stop, num_steps + 1)
+
+ if self._schnell:
+ t = self._time_shift(image_sequence_length, t)
+
+ return t.tolist()
+
+ def random_timesteps(self, B, L, dtype=mx.float32, key=None):
+ if self._schnell:
+ # TODO: Should we upweigh 1 and 0.75?
+ t = mx.random.randint(1, 5, shape=(B,), key=key)
+ t = t.astype(dtype) / 4
+ else:
+ t = mx.random.uniform(shape=(B,), dtype=dtype, key=key)
+ t = self._time_shift(L, t)
+
+ return t
+
+ def sample_prior(self, shape, dtype=mx.float32, key=None):
+ return mx.random.normal(shape, dtype=dtype, key=key)
+
+ def add_noise(self, x, t, noise=None, key=None):
+ noise = (
+ noise
+ if noise is not None
+ else mx.random.normal(x.shape, dtype=x.dtype, key=key)
+ )
+ return x * (1 - t) + t * noise
+
+ def step(self, pred, x_t, t, t_prev):
+ return x_t + (t_prev - t) * pred
diff --git a/xinference/thirdparty/mlx/flux/t5.py b/xinference/thirdparty/mlx/flux/t5.py
new file mode 100644
index 0000000000..cf0515cd5e
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/t5.py
@@ -0,0 +1,244 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+_SHARED_REPLACEMENT_PATTERNS = [
+ (".block.", ".layers."),
+ (".k.", ".key_proj."),
+ (".o.", ".out_proj."),
+ (".q.", ".query_proj."),
+ (".v.", ".value_proj."),
+ ("shared.", "wte."),
+ ("lm_head.", "lm_head.linear."),
+ (".layer.0.layer_norm.", ".ln1."),
+ (".layer.1.layer_norm.", ".ln2."),
+ (".layer.2.layer_norm.", ".ln3."),
+ (".final_layer_norm.", ".ln."),
+ (
+ "layers.0.layer.0.SelfAttention.relative_attention_bias.",
+ "relative_attention_bias.embeddings.",
+ ),
+]
+
+_ENCODER_REPLACEMENT_PATTERNS = [
+ (".layer.0.SelfAttention.", ".attention."),
+ (".layer.1.DenseReluDense.", ".dense."),
+]
+
+
+@dataclass
+class T5Config:
+ vocab_size: int
+ num_layers: int
+ num_heads: int
+ relative_attention_num_buckets: int
+ d_kv: int
+ d_model: int
+ feed_forward_proj: str
+ tie_word_embeddings: bool
+
+ d_ff: Optional[int] = None
+ num_decoder_layers: Optional[int] = None
+ relative_attention_max_distance: int = 128
+ layer_norm_epsilon: float = 1e-6
+
+ @classmethod
+ def from_dict(cls, config):
+ return cls(
+ vocab_size=config["vocab_size"],
+ num_layers=config["num_layers"],
+ num_heads=config["num_heads"],
+ relative_attention_num_buckets=config["relative_attention_num_buckets"],
+ d_kv=config["d_kv"],
+ d_model=config["d_model"],
+ feed_forward_proj=config["feed_forward_proj"],
+ tie_word_embeddings=config["tie_word_embeddings"],
+ d_ff=config.get("d_ff", 4 * config["d_model"]),
+ num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
+ relative_attention_max_distance=config.get(
+ "relative_attention_max_distance", 128
+ ),
+ layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
+ )
+
+
+class RelativePositionBias(nn.Module):
+ def __init__(self, config: T5Config, bidirectional: bool):
+ self.bidirectional = bidirectional
+ self.num_buckets = config.relative_attention_num_buckets
+ self.max_distance = config.relative_attention_max_distance
+ self.n_heads = config.num_heads
+ self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)
+
+ @staticmethod
+ def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
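+ # T5-style bucketing: small offsets get exact buckets, larger offsets
+ # are binned logarithmically up to max_distance; in the bidirectional
+ # case positive offsets use a separate half of the buckets.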
+ num_buckets = num_buckets // 2 if bidirectional else num_buckets
+ max_exact = num_buckets // 2
+
+ abspos = rpos.abs()
+ is_small = abspos < max_exact
+
+ scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
+ buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
+ buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
+
+ buckets = mx.where(is_small, abspos, buckets_large)
+ if bidirectional:
+ buckets = buckets + (rpos > 0) * num_buckets
+ else:
+ buckets = buckets * (rpos < 0)
+
+ return buckets
+
+ def __call__(self, query_length: int, key_length: int, offset: int = 0):
+ """Compute binned relative position bias"""
+ context_position = mx.arange(offset, query_length)[:, None]
+ memory_position = mx.arange(key_length)[None, :]
+
+ # shape (query_length, key_length)
+ relative_position = memory_position - context_position
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position,
+ bidirectional=self.bidirectional,
+ num_buckets=self.num_buckets,
+ max_distance=self.max_distance,
+ )
+
+ # shape (query_length, key_length, num_heads)
+ values = self.embeddings(relative_position_bucket)
+
+ # shape (num_heads, query_length, key_length)
+ return values.transpose(2, 0, 1)
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, config: T5Config):
+ super().__init__()
+ inner_dim = config.d_kv * config.num_heads
+ self.num_heads = config.num_heads
+ self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+ self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+ self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+ self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
+
+ def __call__(
+ self,
+ queries: mx.array,
+ keys: mx.array,
+ values: mx.array,
+ mask: Optional[mx.array],
+ cache: Optional[Tuple[mx.array, mx.array]] = None,
+ ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
+ queries = self.query_proj(queries)
+ keys = self.key_proj(keys)
+ values = self.value_proj(values)
+
+ num_heads = self.num_heads
+ B, L, _ = queries.shape
+ _, S, _ = keys.shape
+ queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+ keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+ values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+ if cache is not None:
+ key_cache, value_cache = cache
+ # Concatenate along the sequence axis; keys/values are (B, H, S, D) here
+ keys = mx.concatenate([key_cache, keys], axis=2)
+ values = mx.concatenate([value_cache, values], axis=2)
+
+ values_hat = mx.fast.scaled_dot_product_attention(
+ queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
+ )
+ values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+ return self.out_proj(values_hat), (keys, values)
+
+
+class DenseActivation(nn.Module):
+ def __init__(self, config: T5Config):
+ super().__init__()
+ mlp_dims = config.d_ff or config.d_model * 4
+ self.gated = config.feed_forward_proj.startswith("gated")
+ if self.gated:
+ self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
+ self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
+ else:
+ self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
+ self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
+ activation = config.feed_forward_proj.removeprefix("gated-")
+ if activation == "relu":
+ self.act = nn.relu
+ elif activation == "gelu":
+ self.act = nn.gelu
+ elif activation == "silu":
+ self.act = nn.silu
+ else:
+ raise ValueError(f"Unknown activation: {activation}")
+
+ def __call__(self, x):
+ if self.gated:
+ hidden_act = self.act(self.wi_0(x))
+ hidden_linear = self.wi_1(x)
+ x = hidden_act * hidden_linear
+ else:
+ x = self.act(self.wi(x))
+ return self.wo(x)
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(self, config: T5Config):
+ super().__init__()
+ self.attention = MultiHeadAttention(config)
+ self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dense = DenseActivation(config)
+
+ def __call__(self, x, mask):
+ y = self.ln1(x)
+ y, _ = self.attention(y, y, y, mask=mask)
+ x = x + y
+
+ y = self.ln2(x)
+ y = self.dense(y)
+ return x + y
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, config: T5Config):
+ super().__init__()
+ self.layers = [
+ TransformerEncoderLayer(config) for i in range(config.num_layers)
+ ]
+ self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
+
+ def __call__(self, x: mx.array):
+ pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
+ pos_bias = pos_bias.astype(x.dtype)
+ for layer in self.layers:
+ x = layer(x, mask=pos_bias)
+ return self.ln(x)
+
+
+class T5Encoder(nn.Module):
+ def __init__(self, config: T5Config):
+ self.wte = nn.Embedding(config.vocab_size, config.d_model)
+ self.encoder = TransformerEncoder(config)
+
+ def sanitize(self, weights):
+ new_weights = {}
+ for k, w in weights.items():
+ for old, new in _SHARED_REPLACEMENT_PATTERNS:
+ k = k.replace(old, new)
+ if k.startswith("encoder."):
+ for old, new in _ENCODER_REPLACEMENT_PATTERNS:
+ k = k.replace(old, new)
+ new_weights[k] = w
+ return new_weights
+
+ def __call__(self, inputs: mx.array):
+ return self.encoder(self.wte(inputs))
diff --git a/xinference/thirdparty/mlx/flux/tokenizers.py b/xinference/thirdparty/mlx/flux/tokenizers.py
new file mode 100644
index 0000000000..796ef3896f
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/tokenizers.py
@@ -0,0 +1,185 @@
+# Copyright © 2024 Apple Inc.
+
+import mlx.core as mx
+import regex
+from sentencepiece import SentencePieceProcessor
+
+
+class CLIPTokenizer:
+ """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
+
+ def __init__(self, bpe_ranks, vocab, max_length=77):
+ self.max_length = max_length
+ self.bpe_ranks = bpe_ranks
+ self.vocab = vocab
+ self.pat = regex.compile(
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+ regex.IGNORECASE,
+ )
+
+ self._cache = {self.bos: self.bos, self.eos: self.eos}
+
+ @property
+ def bos(self):
+ return "<|startoftext|>"
+
+ @property
+ def bos_token(self):
+ return self.vocab[self.bos]
+
+ @property
+ def eos(self):
+ return "<|endoftext|>"
+
+ @property
+ def eos_token(self):
+ return self.vocab[self.eos]
+
+ def bpe(self, text):
+ if text in self._cache:
+ return self._cache[text]
+
+ unigrams = list(text[:-1]) + [text[-1] + "</w>"]
+ unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+ if not unique_bigrams:
+ return unigrams
+
+ # In every iteration try to merge the highest-priority bigram. If none
+ # can be merged we are done.
+ #
+ # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
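+ # Illustrative example (depends on the learned merges): for "low" the
+ # unigrams ["l", "o", "w</w>"] become ["lo", "w</w>"] after merging the
+ # ranked bigram ("l", "o"), and the loop stops once no bigram has a rank.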
+ while unique_bigrams:
+ bigram = min(
+ unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
+ )
+ if bigram not in self.bpe_ranks:
+ break
+
+ new_unigrams = []
+ skip = False
+ for a, b in zip(unigrams, unigrams[1:]):
+ if skip:
+ skip = False
+ continue
+
+ if (a, b) == bigram:
+ new_unigrams.append(a + b)
+ skip = True
+
+ else:
+ new_unigrams.append(a)
+
+ if not skip:
+ new_unigrams.append(b)
+
+ unigrams = new_unigrams
+ unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+ self._cache[text] = unigrams
+
+ return unigrams
+
+ def tokenize(self, text, prepend_bos=True, append_eos=True):
+ if isinstance(text, list):
+ return [self.tokenize(t, prepend_bos, append_eos) for t in text]
+
+ # Lowercase the text, collapse whitespace and split according to
+ # self.pat. Hugging Face does a much more thorough job here but this
+ # should suffice for 95% of cases.
+ clean_text = regex.sub(r"\s+", " ", text.lower())
+ tokens = regex.findall(self.pat, clean_text)
+
+ # Split the tokens according to the byte-pair merge file
+ bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
+
+ # Map to token ids and return
+ tokens = [self.vocab[t] for t in bpe_tokens]
+ if prepend_bos:
+ tokens = [self.bos_token] + tokens
+ if append_eos:
+ tokens.append(self.eos_token)
+
+ if len(tokens) > self.max_length:
+ tokens = tokens[: self.max_length]
+ if append_eos:
+ tokens[-1] = self.eos_token
+
+ return tokens
+
+ def encode(self, text):
+ if not isinstance(text, list):
+ return self.encode([text])
+
+ tokens = self.tokenize(text)
+ length = max(len(t) for t in tokens)
+ for t in tokens:
+ t.extend([self.eos_token] * (length - len(t)))
+
+ return mx.array(tokens)
+
+
+class T5Tokenizer:
+ def __init__(self, model_file, max_length=512):
+ self._tokenizer = SentencePieceProcessor(model_file)
+ self.max_length = max_length
+
+ @property
+ def pad(self):
+ try:
+ return self._tokenizer.id_to_piece(self.pad_token)
+ except IndexError:
+ return None
+
+ @property
+ def pad_token(self):
+ return self._tokenizer.pad_id()
+
+ @property
+ def bos(self):
+ try:
+ return self._tokenizer.id_to_piece(self.bos_token)
+ except IndexError:
+ return None
+
+ @property
+ def bos_token(self):
+ return self._tokenizer.bos_id()
+
+ @property
+ def eos(self):
+ try:
+ return self._tokenizer.id_to_piece(self.eos_token)
+ except IndexError:
+ return None
+
+ @property
+ def eos_token(self):
+ return self._tokenizer.eos_id()
+
+ def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
+ if isinstance(text, list):
+ return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
+
+ tokens = self._tokenizer.encode(text)
+
+ if prepend_bos and self.bos_token >= 0:
+ tokens = [self.bos_token] + tokens
+ if append_eos and self.eos_token >= 0:
+ tokens.append(self.eos_token)
+ if pad and len(tokens) < self.max_length and self.pad_token >= 0:
+ tokens += [self.pad_token] * (self.max_length - len(tokens))
+
+ return tokens
+
+ def encode(self, text, pad=True):
+ if not isinstance(text, list):
+ return self.encode([text], pad=pad)
+
+ pad_token = self.pad_token if self.pad_token >= 0 else 0
+ tokens = self.tokenize(text, pad=pad)
+ length = max(len(t) for t in tokens)
+ for t in tokens:
+ t.extend([pad_token] * (length - len(t)))
+
+ return mx.array(tokens)
diff --git a/xinference/thirdparty/mlx/flux/trainer.py b/xinference/thirdparty/mlx/flux/trainer.py
new file mode 100644
index 0000000000..40a126e886
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/trainer.py
@@ -0,0 +1,98 @@
+import mlx.core as mx
+import numpy as np
+from PIL import Image, ImageFile
+from tqdm import tqdm
+
+from .datasets import Dataset
+from .flux import FluxPipeline
+
+
+class Trainer:
+
+ def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
+ self.flux = flux
+ self.dataset = dataset
+ self.args = args
+ self.latents = []
+ self.t5_features = []
+ self.clip_features = []
+
+ def _random_crop_resize(self, img):
+ resolution = self.args.resolution
+ width, height = img.size
+
+ a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
+
+ # Randomly crop the input image to between 0.8 and 1.0 of its original dimensions
+ crop_size = (
+ max((0.8 + 0.2 * a) * width, resolution[0]),
+ max((0.8 + 0.2 * b) * height, resolution[1]),
+ )
+ pan = (width - crop_size[0], height - crop_size[1])
+ img = img.crop(
+ (
+ pan[0] * c,
+ pan[1] * d,
+ crop_size[0] + pan[0] * c,
+ crop_size[1] + pan[1] * d,
+ )
+ )
+
+ # Fit the largest rectangle with the aspect ratio of `resolution`
+ # inside the cropped image.
+ width, height = crop_size
+ ratio = resolution[0] / resolution[1]
+ r1 = (height * ratio, height)
+ r2 = (width, width / ratio)
+ r = r1 if r1[0] <= width else r2
+ img = img.crop(
+ (
+ (width - r[0]) / 2,
+ (height - r[1]) / 2,
+ (width + r[0]) / 2,
+ (height + r[1]) / 2,
+ )
+ )
+
+ # Finally resize the image to resolution
+ img = img.resize(resolution, Image.LANCZOS)
+
+ return mx.array(np.array(img))
+
+ def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
+ for i in range(num_augmentations):
+ img = self._random_crop_resize(input_img)
+ img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
+ x_0 = self.flux.ae.encode(img[None])
+ x_0 = x_0.astype(self.flux.dtype)
+ mx.eval(x_0)
+ self.latents.append(x_0)
+
+ def _encode_prompt(self, prompt):
+ t5_tok, clip_tok = self.flux.tokenize([prompt])
+ t5_feat = self.flux.t5(t5_tok)
+ clip_feat = self.flux.clip(clip_tok).pooled_output
+ mx.eval(t5_feat, clip_feat)
+ self.t5_features.append(t5_feat)
+ self.clip_features.append(clip_feat)
+
+ def encode_dataset(self):
+ """Encode the images & prompt in the latent space to prepare for training."""
+ self.flux.ae.eval()
+ for image, prompt in tqdm(self.dataset, desc="encode dataset"):
+ self._encode_image(image, self.args.num_augmentations)
+ self._encode_prompt(prompt)
+
+ def iterate(self, batch_size):
+ xs = mx.concatenate(self.latents)
+ t5 = mx.concatenate(self.t5_features)
+ clip = mx.concatenate(self.clip_features)
+ mx.eval(xs, t5, clip)
+ n_aug = self.args.num_augmentations
+ while True:
+ x_indices = mx.random.permutation(len(self.latents))
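+ # Each prompt is encoded once but its image n_aug times, so dividing a
+ # latent index by n_aug recovers the index of the matching text features.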
+ c_indices = x_indices // n_aug
+ for i in range(0, len(self.latents), batch_size):
+ x_i = x_indices[i : i + batch_size]
+ c_i = c_indices[i : i + batch_size]
+ yield xs[x_i], t5[c_i], clip[c_i]
diff --git a/xinference/thirdparty/mlx/flux/utils.py b/xinference/thirdparty/mlx/flux/utils.py
new file mode 100644
index 0000000000..47e7fe9e33
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/utils.py
@@ -0,0 +1,179 @@
+# Copyright © 2024 Apple Inc.
+
+import json
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+import mlx.core as mx
+
+from .autoencoder import AutoEncoder, AutoEncoderParams
+from .clip import CLIPTextModel, CLIPTextModelConfig
+from .model import Flux, FluxParams
+from .t5 import T5Config, T5Encoder
+from .tokenizers import CLIPTokenizer, T5Tokenizer
+
+
+@dataclass
+class ModelSpec:
+ params: FluxParams
+ ae_params: AutoEncoderParams
+ ckpt_path: Optional[str]
+ ae_path: Optional[str]
+ repo_id: Optional[str]
+ repo_flow: Optional[str]
+ repo_ae: Optional[str]
+
+
+configs = {
+ "flux-dev": ModelSpec(
+ repo_id="black-forest-labs/FLUX.1-dev",
+ repo_flow="flux1-dev.safetensors",
+ repo_ae="ae.safetensors",
+ ckpt_path=os.getenv("FLUX_DEV"),
+ params=FluxParams(
+ in_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=True,
+ ),
+ ae_path=os.getenv("AE"),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
+ "flux-schnell": ModelSpec(
+ repo_id="black-forest-labs/FLUX.1-schnell",
+ repo_flow="flux1-schnell.safetensors",
+ repo_ae="ae.safetensors",
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
+ params=FluxParams(
+ in_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=False,
+ ),
+ ae_path=os.getenv("AE"),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
+}
+
+
+def load_flow_model(name: str, ckpt_path: str):
+ # Make the model
+ model = Flux(configs[name].params)
+
+ # Load the checkpoint if needed
+ if os.path.isdir(ckpt_path):
+ ckpt_path = os.path.join(ckpt_path, configs[name].repo_flow)
+ weights = mx.load(ckpt_path)
+ weights = model.sanitize(weights)
+ model.load_weights(list(weights.items()))
+
+ return model
+
+
+def load_ae(name: str, ckpt_path: str):
+ # Make the autoencoder
+ ae = AutoEncoder(configs[name].ae_params)
+
+ # Load the checkpoint if needed
+ ckpt_path = os.path.join(ckpt_path, "ae.safetensors")
+ weights = mx.load(ckpt_path)
+ weights = ae.sanitize(weights)
+ ae.load_weights(list(weights.items()))
+
+ return ae
+
+
+def load_clip(name: str, ckpt_path: str):
+ config_path = os.path.join(ckpt_path, "text_encoder/config.json")
+ with open(config_path) as f:
+ config = CLIPTextModelConfig.from_dict(json.load(f))
+
+ # Make the clip text encoder
+ clip = CLIPTextModel(config)
+
+ ckpt_path = os.path.join(ckpt_path, "text_encoder/model.safetensors")
+ weights = mx.load(ckpt_path)
+ weights = clip.sanitize(weights)
+ clip.load_weights(list(weights.items()))
+
+ return clip
+
+
+def load_t5(name: str, ckpt_path: str):
+ config_path = os.path.join(ckpt_path, "text_encoder_2/config.json")
+ with open(config_path) as f:
+ config = T5Config.from_dict(json.load(f))
+
+ # Make the T5 model
+ t5 = T5Encoder(config)
+
+ model_index = os.path.join(ckpt_path, "text_encoder_2/model.safetensors.index.json")
+ weight_files = set()
+ with open(model_index) as f:
+ for _, w in json.load(f)["weight_map"].items():
+ weight_files.add(w)
+ weights = {}
+ for w in weight_files:
+ w = f"text_encoder_2/{w}"
+ w = os.path.join(ckpt_path, w)
+ weights.update(mx.load(w))
+ weights = t5.sanitize(weights)
+ t5.load_weights(list(weights.items()))
+
+ return t5
+
+
+def load_clip_tokenizer(name: str, ckpt_path: str):
+ vocab_file = os.path.join(ckpt_path, "tokenizer/vocab.json")
+ with open(vocab_file, encoding="utf-8") as f:
+ vocab = json.load(f)
+
+ merges_file = os.path.join(ckpt_path, "tokenizer/merges.txt")
+ with open(merges_file, encoding="utf-8") as f:
+ bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
+ bpe_merges = [tuple(m.split()) for m in bpe_merges]
+ bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
+
+ return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
+
+
+def load_t5_tokenizer(name: str, ckpt_path: str, pad: bool = True):
+ model_file = os.path.join(ckpt_path, "tokenizer_2/spiece.model")
+ return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
diff --git a/xinference/web/ui/src/scenes/launch_model/modelCard.js b/xinference/web/ui/src/scenes/launch_model/modelCard.js
index a4391f830a..bcffbace11 100644
--- a/xinference/web/ui/src/scenes/launch_model/modelCard.js
+++ b/xinference/web/ui/src/scenes/launch_model/modelCard.js
@@ -327,6 +327,9 @@ const ModelCard = ({
modelDataWithID_LLM.n_gpu_layers = nGPULayers
}
+ const modelDataWithID =
+ modelType === 'LLM' ? modelDataWithID_LLM : modelDataWithID_other
+
if (
loraListArr.length ||
imageLoraLoadKwargsArr.length ||
@@ -354,12 +357,9 @@ const ModelCard = ({
})
peft_model_config['lora_list'] = lora_list
}
- modelDataWithID_LLM['peft_model_config'] = peft_model_config
+ modelDataWithID['peft_model_config'] = peft_model_config
}
- const modelDataWithID =
- modelType === 'LLM' ? modelDataWithID_LLM : modelDataWithID_other
-
if (customParametersArr.length) {
customParametersArr.forEach((item) => {
modelDataWithID[item.key] = handleValueType(item.value)
@@ -376,10 +376,15 @@ const ModelCard = ({
`/running_models/${modelType}`
)
let historyArr = JSON.parse(localStorage.getItem('historyArr')) || []
- if (!historyArr.some((item) => deepEqual(item, modelDataWithID))) {
- historyArr = historyArr.filter(
- (item) => item.model_name !== modelDataWithID.model_name
- )
+ const historyModelNameArr = historyArr.map((item) => item.model_name)
+ if (historyModelNameArr.includes(modelDataWithID.model_name)) {
+ historyArr = historyArr.map((item) => {
+ if (item.model_name === modelDataWithID.model_name) {
+ return modelDataWithID
+ }
+ return item
+ })
+ } else {
historyArr.push(modelDataWithID)
}
localStorage.setItem('historyArr', JSON.stringify(historyArr))
@@ -600,28 +605,10 @@ const ModelCard = ({
})
setLoraArr(loraData)
- let ImageLoraLoadData = []
- for (let key in peft_model_config?.image_lora_load_kwargs) {
- ImageLoraLoadData.push({
- key: key,
- value: peft_model_config?.image_lora_load_kwargs[key],
- })
- }
- setImageLoraLoadArr(ImageLoraLoadData)
-
- let ImageLoraFuseData = []
- for (let key in peft_model_config?.image_lora_fuse_kwargs) {
- ImageLoraFuseData.push({
- key: key,
- value: peft_model_config?.image_lora_fuse_kwargs[key],
- })
- }
- setImageLoraFuseArr(ImageLoraFuseData)
-
let customData = []
for (let key in arr[0]) {
!llmAllDataKey.includes(key) &&
- customData.push({ key: key, value: arr[0][key] })
+ customData.push({ key: key, value: arr[0][key] || 'none' })
}
setCustomArr(customData)
@@ -635,11 +622,7 @@ const ModelCard = ({
)
setIsOther(true)
- if (
- loraData.length ||
- ImageLoraLoadData.length ||
- ImageLoraFuseData.length
- ) {
+ if (loraData.length) {
setIsOther(true)
setIsPeftModelConfig(true)
}
@@ -657,37 +640,54 @@ const ModelCard = ({
setDownloadHub(arr[0].download_hub || '')
setModelPath(arr[0].model_path || '')
+ if (arr[0].model_type === 'image') {
+ let loraData = []
+ arr[0].peft_model_config?.lora_list?.forEach((item) => {
+ loraData.push({
+ lora_name: item.lora_name,
+ local_path: item.local_path,
+ })
+ })
+ setLoraArr(loraData)
+
+ let ImageLoraLoadData = []
+ for (let key in arr[0].peft_model_config?.image_lora_load_kwargs) {
+ ImageLoraLoadData.push({
+ key: key,
+ value:
+ arr[0].peft_model_config?.image_lora_load_kwargs[key] || 'none',
+ })
+ }
+ setImageLoraLoadArr(ImageLoraLoadData)
+
+ let ImageLoraFuseData = []
+ for (let key in arr[0].peft_model_config?.image_lora_fuse_kwargs) {
+ ImageLoraFuseData.push({
+ key: key,
+ value:
+ arr[0].peft_model_config?.image_lora_fuse_kwargs[key] || 'none',
+ })
+ }
+ setImageLoraFuseArr(ImageLoraFuseData)
+
+ if (
+ loraData.length ||
+ ImageLoraLoadData.length ||
+ ImageLoraFuseData.length
+ ) {
+ setIsPeftModelConfig(true)
+ }
+ }
+
let customData = []
for (let key in arr[0]) {
!llmAllDataKey.includes(key) &&
- customData.push({ key: key, value: arr[0][key] })
+ customData.push({ key: key, value: arr[0][key] || 'none' })
}
setCustomArr(customData)
}
}
- const deepEqual = (obj1, obj2) => {
- if (obj1 === obj2) return true
- if (
- typeof obj1 !== 'object' ||
- typeof obj2 !== 'object' ||
- obj1 == null ||
- obj2 == null
- ) {
- return false
- }
-
- let keysA = Object.keys(obj1)
- let keysB = Object.keys(obj2)
- if (keysA.length !== keysB.length) return false
- for (let key of keysA) {
- if (!keysB.includes(key) || !deepEqual(obj1[key], obj2[key])) {
- return false
- }
- }
- return true
- }
-
const handleCollection = (bool) => {
setHover(false)
@@ -725,8 +725,6 @@ const ModelCard = ({
setDownloadHub('')
setModelPath('')
setLoraArr([])
- setImageLoraLoadArr([])
- setImageLoraFuseArr([])
setCustomArr([])
setIsOther(false)
setIsPeftModelConfig(false)
@@ -738,6 +736,11 @@ const ModelCard = ({
setWorkerIp('')
setDownloadHub('')
setModelPath('')
+ setLoraArr([])
+ setImageLoraLoadArr([])
+ setImageLoraFuseArr([])
+ setCustomArr([])
+ setIsPeftModelConfig(false)
}
}
@@ -991,7 +994,14 @@ const ModelCard = ({
{(() => {
if (modelData.language) {
return modelData.language.map((v) => {
- return
+ return (
+
+ )
})
} else if (modelData.model_family) {
return (
@@ -1446,30 +1456,6 @@ const ModelCard = ({
onJudgeArr={judgeArr}
pairData={loraArr}
/>
- {
- setImageLoraLoadKwargsArr(arr)
- }}
- onJudgeArr={judgeArr}
- pairData={imageLoraLoadArr}
- />
- {
- setImageLoraFuseKwargsArr(arr)
- }}
- onJudgeArr={judgeArr}
- pairData={imageLoraFuseArr}
- />
) : (
-
- setModelUID(e.target.value)}
- />
- setReplica(parseInt(e.target.value, 10))}
- />
+
- Device
-
-
- {nGpu === 'GPU' && (
+ setModelUID(e.target.value)}
+ />
+ setReplica(parseInt(e.target.value, 10))}
+ />
+
+ Device
+
+
+ {nGpu === 'GPU' && (
+
+ {
+ setGPUIdxAlert(false)
+ setGPUIdx(e.target.value)
+ const regular = /^\d+(?:,\d+)*$/
+ if (
+ e.target.value !== '' &&
+ !regular.test(e.target.value)
+ ) {
+ setGPUIdxAlert(true)
+ }
+ }}
+ />
+ {GPUIdxAlert && (
+
+ Please enter numeric data separated by commas, for
+ example: 0,1,2
+
+ )}
+
+ )}
setWorkerIp(e.target.value)}
+ />
+
+
+
+ (Optional) Download_hub
+
+
+
+
+ setModelPath(e.target.value)}
/>
- {GPUIdxAlert && (
-
- Please enter numeric data separated by commas, for
- example: 0,1,2
-
- )}
- )}
-
- setWorkerIp(e.target.value)}
- />
-
-
-
- (Optional) Download_hub
-
-
-
-
- setModelPath(e.target.value)}
+ onGetArr={(arr) => {
+ setCustomParametersArr(arr)
+ }}
+ onJudgeArr={judgeArr}
+ pairData={customArr}
/>
- {
- setCustomParametersArr(arr)
- }}
- onJudgeArr={judgeArr}
- pairData={customArr}
- />
-
+
)}