diff --git a/.github/workflows/build-embedding.yml b/.github/workflows/build-embedding.yml deleted file mode 100644 index 0c460d96b..000000000 --- a/.github/workflows/build-embedding.yml +++ /dev/null @@ -1,239 +0,0 @@ -name: Build Embedding Bentos -on: - workflow_dispatch: - push: - branches: - - 'main' - tags: - - '*' - paths: - - '.github/workflows/build-embedding.yml' - - 'openllm-python/src/openllm/**' - - 'openllm-core/src/openllm_core/**' - - 'openllm-client/src/openllm_client/**' - pull_request: - branches: - - 'main' - paths: - - '.github/workflows/build-embedding.yml' - - 'openllm-python/src/openllm/**' - - 'openllm-core/src/openllm_core/**' - - 'openllm-client/src/openllm_client/**' - types: [labeled, opened, synchronize, reopened] -# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun -defaults: - run: - shell: bash --noprofile --norc -exo pipefail {0} -env: - LINES: 120 - COLUMNS: 120 - AWS_REGION: us-east-1 - OPENLLM_OPT_MODEL_ID: facebook/opt-125m - BENTOML_HOME: ${{ github.workspace }}/bentoml - OPENLLM_DEV_BUILD: True - OPENLLM_DO_NOT_TRACK: True -concurrency: - group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true -jobs: - get_commit_message: - name: Get commit message - runs-on: ubuntu-latest - if: "github.repository == 'bentoml/OpenLLM'" # Don't run on fork repository - outputs: - message: ${{ steps.commit_message.outputs.message }} - steps: - - uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # ratchet:actions/checkout@v4.1.0 - # Gets the correct commit message for pull request - with: - ref: ${{ github.event.pull_request.head.sha }} - - name: Get commit message - id: commit_message - run: | - set -xe - COMMIT_MSG=$(git log --no-merges -1 --oneline) - echo "message=$COMMIT_MSG" >> $GITHUB_OUTPUT - echo github.ref ${{ github.ref }} - start-runner: - name: Start self-hosted EC2 runner - runs-on: ubuntu-latest - needs: get_commit_message - if: >- - contains(needs.get_commit_message.outputs.message, '[ec2 build]') || github.event_name == 'workflow_dispatch' || (github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, '00 - EC2 Build')) || (github.event_name == 'push' && (startsWith(github.ref, 'refs/tags/v') || startsWith(github.ref, 'refs/heads/main'))) - env: - EC2_INSTANCE_TYPE: t3.2xlarge - EC2_AMI_ID: ami-0fc9d48803f691665 - EC2_SUBNET_ID: subnet-0f3cfaf555c0fe5d7,subnet-03c02763156f1c011,subnet-01e191856710e5205,subnet-06caca1b04878bf17,subnet-0ec43be52d7ca5619,subnet-0f23c41d786013d15 - EC2_SECURITY_GROUP: sg-0b84a8e57c4524eb9 - outputs: - label: ${{ steps.start-ec2-runner.outputs.label }} - ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@010d0da01d0b5a38af31e9c3470dbfdabdecca3a # ratchet:aws-actions/configure-aws-credentials@v4.0.1 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ env.AWS_REGION }} - - name: Start EC2 Runner - id: start-ec2-runner - uses: aarnphm/ec2-github-runner@main # ratchet:exclude - with: - mode: start - github-token: ${{ secrets.OPENLLM_PAT }} - ec2-region: ${{ env.AWS_REGION }} - ec2-image-id: ${{ env.EC2_AMI_ID }} - ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }} - subnet-id: ${{ env.EC2_SUBNET_ID }} - security-group-id: ${{ env.EC2_SECURITY_GROUP }} - build-and-push-embedding-bento: - 
name: Build embedding container - needs: start-runner - runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner - permissions: - contents: write - packages: write - # This is used to complete the identity challenge - # with sigstore/fulcio when running outside of PRs. - id-token: write - security-events: write - steps: - - uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # ratchet:actions/checkout@v4.1.0 - with: - fetch-depth: 0 - - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1 - with: - bentoml-version: 'main' - python-version: '3.11' - - name: Inject slug/short variables - uses: rlespinasse/github-slug-action@102b1a064a9b145e56556e22b18b19c624538d94 # ratchet:rlespinasse/github-slug-action@v4.4.1 - - name: Set up QEMU - uses: docker/setup-qemu-action@68827325e0b33c7199eb31dd4e31fbe9023e06e3 # ratchet:docker/setup-qemu-action@v3.0.0 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@f95db51fddba0c2d1ec667646a06c2ce06100226 # ratchet:docker/setup-buildx-action@v3.0.0 - with: - install: true - driver-opts: | - image=moby/buildkit:master - network=host - - name: Install cosign - if: github.event_name != 'pull_request' - uses: sigstore/cosign-installer@11086d25041f77fe8fe7b9ea4e48e3b9192b8f19 # ratchet:sigstore/cosign-installer@v3.1.2 - with: - cosign-release: 'v2.1.1' - - name: Login to GitHub Container Registry - uses: docker/login-action@343f7c4344506bcbf9b4de18042ae17996df046d # ratchet:docker/login-action@v3.0.0 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: Extract metadata tags and labels on PRs - if: github.event_name == 'pull_request' - id: meta-pr - uses: docker/metadata-action@96383f45573cb7f253c731d3b3ab81c87ef81934 # ratchet:docker/metadata-action@v5.0.0 - with: - images: ghcr.io/bentoml/openllm-embedding - tags: | - type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }} - - name: Extract metadata tags and labels for main, release or tag - if: github.event_name != 'pull_request' - id: meta - uses: docker/metadata-action@96383f45573cb7f253c731d3b3ab81c87ef81934 # ratchet:docker/metadata-action@v5.0.0 - with: - flavor: latest=auto - images: ghcr.io/bentoml/openllm-embedding - tags: | - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} - type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }} - - name: Build OPT Bento with base embeddings - id: bento-tag - run: | - bash local.sh - pip install 'build[virtualenv]==0.10.0' - openllm build opt --serialisation legacy --bento-version sha-${{ env.GITHUB_SHA_SHORT }} --machine --dockerfile-template - < /etc/apt/apt.conf.d/keep-cache - {% call common.RUN(__enable_buildkit__) -%} {{ common.mount_cache(__lib_apt__) }} {{ common.mount_cache(__cache_apt__) }} {% endcall -%} set -eux && \ - apt-get update -y && \ - apt-get install -q -y --no-install-recommends --allow-remove-essential \ - ca-certificates gnupg2 bash build-essential {% if __options__system_packages is not none %}{{ __options__system_packages | join(' ') }}{% endif -%} - {% endblock %} - EOF - bento_tag=$(python -c "import openllm;print(str(openllm.build('opt',bento_version='sha-${{ env.GITHUB_SHA_SHORT }}',serialisation='legacy').tag))") - echo "tag=$bento_tag" >> $GITHUB_OUTPUT - - name: Build and push Embedding Bento - id: build-and-push - uses: 
bentoml/containerize-push-action@main # ratchet:exclude - with: - bento-tag: ${{ steps.bento-tag.outputs.tag }} - platforms: linux/amd64 - push: true - build-args: | - GIT_SHA=${{ env.GITHUB_SHA }} - DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} - tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} - labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} - - name: Sign the released image - if: ${{ github.event_name != 'pull_request' }} - env: - COSIGN_EXPERIMENTAL: 'true' - run: echo "${{ steps.meta.outputs.tags }}" | xargs -I {} cosign sign --yes {}@${{ steps.build-and-push.outputs.digest }} - - name: Run Trivy in GitHub SBOM mode and submit results to Dependency Graph - uses: aquasecurity/trivy-action@fbd16365eb88e12433951383f5e99bd901fc618f # ratchet:aquasecurity/trivy-action@master - if: ${{ github.event_name != 'pull_request' }} - with: - image-ref: 'ghcr.io/bentoml/openllm-embedding:sha-${{ env.GITHUB_SHA_SHORT }}' - format: 'github' - output: 'dependency-results.sbom.json' - github-pat: ${{ secrets.UI_GITHUB_TOKEN }} - scanners: 'vuln' - - name: Run Trivy vulnerability scanner - uses: aquasecurity/trivy-action@fbd16365eb88e12433951383f5e99bd901fc618f # ratchet:aquasecurity/trivy-action@master - if: ${{ github.event_name != 'pull_request' }} - with: - image-ref: 'ghcr.io/bentoml/openllm-embedding:sha-${{ env.GITHUB_SHA_SHORT }}' - format: 'sarif' - output: 'trivy-results.sarif' - severity: 'CRITICAL' - scanners: 'vuln' - - name: Upload Trivy scan results to GitHub Security tab - uses: github/codeql-action/upload-sarif@d90b8d79de6dc1f58e83a1499aa58d6c93dc28de # ratchet:github/codeql-action/upload-sarif@v2.22.2 - if: ${{ github.event_name != 'pull_request' }} - with: - sarif_file: 'trivy-results.sarif' - stop-runner: - name: Stop self-hosted EC2 runner - needs: - - start-runner - - build-and-push-embedding-bento - - get_commit_message - runs-on: ubuntu-latest - if: >- - (contains(needs.get_commit_message.outputs.message, '[ec2 build]') || github.event_name == 'workflow_dispatch' || (github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, '00 - EC2 Build')) || (github.event_name == 'push' && (startsWith(github.ref, 'refs/tags/v') || startsWith(github.ref, 'refs/heads/main')))) && always() - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@010d0da01d0b5a38af31e9c3470dbfdabdecca3a # ratchet:aws-actions/configure-aws-credentials@v4.0.1 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: ${{ env.AWS_REGION }} - - name: Stop EC2 runner - uses: aarnphm/ec2-github-runner@af796d217e24ecbbc5a2c49e780cd90616e2b962 # ratchet:aarnphm/ec2-github-runner@main - with: - mode: stop - github-token: ${{ secrets.OPENLLM_PAT }} - ec2-region: ${{ env.AWS_REGION }} - label: ${{ needs.start-runner.outputs.label }} - ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cea6976ef..914c816da 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,8 +14,7 @@ repos: verbose: true exclude: | (?x)^( - openllm-client/src/openllm_client/pb.*| - openllm-python/src/openllm/cli/entrypoint.py + openllm-client/src/openllm_client/pb.* )$ - repo: https://github.com/astral-sh/ruff-pre-commit rev: 'v0.0.292' diff --git a/README.md b/README.md index 99b69502e..d9e254404 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,6 @@ Options: 
Commands: build Package a given models into a Bento. - embed Get embeddings interactively, from a terminal. import Setup LLM interactively. instruct Instruct agents interactively for given tasks, from a... models List all supported models. @@ -867,47 +866,6 @@ openllm build opt --adapter-id ./path/to/adapter_id --build-ctx . > We will gradually roll out support for fine-tuning all models. > Currently, the models supporting fine-tuning with OpenLLM include: OPT, Falcon, and LlaMA. -## 🧮 Embeddings - -OpenLLM provides embeddings endpoint for embeddings calculation. This can -be accessed via `/v1/embeddings`. - -To use via CLI, simply call `openllm embed`: - -```bash -openllm embed --endpoint http://localhost:3000 "I like to eat apples" -o json -{ - "embeddings": [ - 0.006569798570126295, - -0.031249752268195152, - -0.008072729222476482, - 0.00847396720200777, - -0.005293501541018486, - ...... - -0.002078012563288212, - -0.00676426338031888, - -0.002022686880081892 - ], - "num_tokens": 9 -} -``` - -To invoke this endpoint, use `client.embed` from the Python SDK: - -```python -import openllm - -client = openllm.client.HTTPClient("http://localhost:3000") - -client.embed("I like to eat apples") -``` - -> [!NOTE] -> Currently, the following model family supports embeddings calculation: Llama, T5 (Flan-T5, FastChat, etc.), ChatGLM -> For the remaining LLM that doesn't have specific embedding implementation, -> we will use a generic [BertModel](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) -> for embeddings generation. The implementation is largely based on [`bentoml/sentence-embedding-bento`](https://github.com/bentoml/sentence-embedding-bento) - ## 🥅 Playground and Chat UI The following UIs are currently available for OpenLLM: diff --git a/changelog.d/500.breaking.md b/changelog.d/500.breaking.md new file mode 100644 index 000000000..5de957992 --- /dev/null +++ b/changelog.d/500.breaking.md @@ -0,0 +1,5 @@ +Remove the embeddings endpoints from the provided API, as they are probably not a good fit to have here yet. + +This means that `openllm embed` will also be removed.
+ +Client implementation is also updated to fix 0.3.7 breaking changes with models other than Llama diff --git a/openllm-client/src/openllm_client/_http.py b/openllm-client/src/openllm_client/_http.py index 4e6f425ac..e0b4278e4 100644 --- a/openllm-client/src/openllm_client/_http.py +++ b/openllm-client/src/openllm_client/_http.py @@ -1,5 +1,11 @@ from __future__ import annotations +import asyncio +import enum +import logging +import os +import time import typing as t +import urllib.error from urllib.parse import urlparse @@ -11,22 +17,65 @@ from ._schemas import Response from ._schemas import StreamResponse +logger = logging.getLogger(__name__) + def _address_validator(_: t.Any, attr: attr.Attribute[t.Any], value: str) -> None: if not isinstance(value, str): raise TypeError(f'{attr.name} must be a string') if not urlparse(value).netloc: raise ValueError(f'{attr.name} must be a valid URL') -@attr.define -class HTTPClient: - address: str = attr.field(validator=_address_validator, converter=lambda addr: addr if '://' in addr else 'http://' + addr) - api_version: str = 'v1' - timeout: int = 30 - client_args: t.Dict[str, t.Any] = attr.field(factory=dict) - __metadata: dict[str, t.Any] = attr.field(default=None) - __config: dict[str, t.Any] = attr.field(default=None) - _inner: httpx.Client = attr.field(init=False, repr=False) +def _address_converter(addr: str) -> str: + return addr if '://' in addr else 'http://' + addr + +class ServerState(enum.Enum): + # CLOSED: The server is not yet ready or `wait_until_server_ready` has not been called/failed. + CLOSED = 1 + # READY: The server is ready and `wait_until_server_ready` has been called. + READY = 2 - def __attrs_post_init__(self) -> None: - self._inner = httpx.Client(base_url=self.address, timeout=self.timeout, **self.client_args) +_object_setattr = object.__setattr__ + +@attr.define(init=False) +class HTTPClient: + address: str = attr.field(validator=_address_validator, converter=_address_converter) + client_args: t.Dict[str, t.Any] = attr.field() + _inner: httpx.Client = attr.field(repr=False) + + _timeout: int = attr.field(default=30, repr=False) + _api_version: str = attr.field(default='v1', repr=False) + _state: ServerState = attr.field(default=ServerState.CLOSED, repr=False) + + __metadata: dict[str, t.Any] | None = attr.field(default=None, repr=False) + __config: dict[str, t.Any] | None = attr.field(default=None, repr=False) + + @staticmethod + def wait_until_server_ready(addr: str, timeout: float = 30, check_interval: int = 1, **client_args: t.Any) -> None: + addr = _address_converter(addr) + logger.debug('Wait for server @ %s to be ready', addr) + start = time.monotonic() + while time.monotonic() - start < timeout: + try: + with httpx.Client(base_url=addr, **client_args) as sess: + status = sess.get('/readyz').status_code + if status == 200: break + else: time.sleep(check_interval) + except (httpx.ConnectError, urllib.error.URLError, ConnectionError): + logger.debug('Server is not ready yet, retrying in %d seconds...', check_interval) + time.sleep(check_interval) + # Try once more and raise for exception + try: + with httpx.Client(base_url=addr, **client_args) as sess: + status = sess.get('/readyz').status_code + except httpx.HTTPStatusError as err: + logger.error('Failed to wait until server ready: %s', addr) + logger.error(err) + raise + + def __init__(self, address: str | None = None, timeout: int = 30, api_version: str = 'v1', **client_args: t.Any) -> None: + if address is None: + env = os.environ.get('OPENLLM_ENDPOINT') + if env 
is None: raise ValueError('address must be provided') + address = env + self.__attrs_init__(address, client_args, httpx.Client(base_url=address, timeout=timeout, **client_args), timeout, api_version) # type: ignore[attr-defined] def _metadata(self) -> dict[str, t.Any]: if self.__metadata is None: self.__metadata = self._inner.post(self._build_endpoint('metadata')).json() @@ -39,30 +88,51 @@ def _config(self) -> dict[str, t.Any]: self.__config = {**config, **generation_config} return self.__config - def health(self): - return self._inner.get('/readyz') + def __del__(self) -> None: + self._inner.close() - def _build_endpoint(self, endpoint: str): - return '/' + f'{self.api_version}/{endpoint}' + def _build_endpoint(self, endpoint: str) -> str: + return '/' + f'{self._api_version}/{endpoint}' - def query(self, prompt: str, **attrs: t.Any) -> Response: - req = Request(prompt=self._metadata()['prompt_template'].format(system_message=self._metadata()['system_message'], instruction=prompt), llm_config={**self._config(), **attrs}) - r = self._inner.post(self._build_endpoint('generate'), json=req.json(), **self.client_args) - payload = r.json() - if r.status_code != 200: raise ValueError("Failed to get generation from '/v1/generate'. Check server logs for more details.") - return Response(**payload) + @property + def is_ready(self) -> bool: + return self._state == ServerState.READY def generate(self, prompt: str, **attrs: t.Any) -> Response: return self.query(prompt, **attrs) + def health(self) -> None: + try: + self.wait_until_server_ready(self.address, timeout=self._timeout, **self.client_args) + _object_setattr(self, '_state', ServerState.READY) + except Exception as e: + logger.error('Server is not healthy (Scroll up for traceback)\n%s', e) + _object_setattr(self, '_state', ServerState.CLOSED) + + def query(self, prompt: str, **attrs: t.Any) -> Response: + timeout = attrs.pop('timeout', self._timeout) + _meta, _config = self._metadata(), self._config() + if _meta['prompt_template'] is not None: prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt) + + if not self.is_ready: + self.health() + if not self.is_ready: raise RuntimeError('Server is not ready. Check server logs for more information.') + + with httpx.Client(base_url=self.address, timeout=timeout, **self.client_args) as client: + r = client.post(self._build_endpoint('generate'), json=Request(prompt=prompt, llm_config={**_config, **attrs}).json(), **self.client_args) + if r.status_code != 200: raise ValueError("Failed to get generation from '/v1/generate'. 
Check server logs for more details.") + return Response(**r.json()) + def generate_stream(self, prompt: str, **attrs: t.Any) -> t.Iterator[StreamResponse]: - req = Request(prompt=self._metadata()['prompt_template'].format(system_message=self._metadata()['system_message'], instruction=prompt), llm_config={**self._config(), **attrs}) - with self._inner.stream('POST', self._build_endpoint('generate_stream'), json=req.json(), **self.client_args) as r: - for payload in r.iter_bytes(): - # Skip line - payload = payload.decode('utf-8') - yield StreamResponse(text=payload) + timeout = attrs.pop('timeout', self._timeout) + _meta, _config = self._metadata(), self._config() + if _meta['prompt_template'] is not None: prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt) + with httpx.Client(base_url=self.address, timeout=timeout, **self.client_args) as client: + with client.stream('POST', self._build_endpoint('generate_stream'), json=Request(prompt=prompt, llm_config={**_config, **attrs}).json(), **self.client_args) as r: + for payload in r.iter_bytes(): + yield StreamResponse(text=payload.decode('utf-8')) # TODO: make it SSE correct for streaming + # Skip line # if payload == b"\n": continue # payload = payload.decode("utf-8") # if payload.startswith("data:"): @@ -72,21 +142,48 @@ def generate_stream(self, prompt: str, **attrs: t.Any) -> t.Iterator[StreamRespo # except Exception as e: print(e) # yield resp - def __del__(self) -> None: - self._inner.close() - -@attr.define +@attr.define(init=False) class AsyncHTTPClient: - address: str = attr.field(validator=_address_validator, converter=lambda addr: addr if '://' in addr else 'http://' + addr) - api_version: str = 'v1' - timeout: int = 30 - client_args: t.Dict[str, t.Any] = attr.field(factory=dict) - __metadata: dict[str, t.Any] = attr.field(default=None) - __config: dict[str, t.Any] = attr.field(default=None) - _inner: httpx.AsyncClient = attr.field(init=False, repr=False) - - def __attrs_post_init__(self) -> None: - self._inner = httpx.AsyncClient(base_url=self.address, timeout=self.timeout, **self.client_args) + address: str = attr.field(validator=_address_validator, converter=_address_converter) + client_args: t.Dict[str, t.Any] = attr.field() + _inner: httpx.AsyncClient = attr.field(repr=False) + + _timeout: int = attr.field(default=30, repr=False) + _api_version: str = attr.field(default='v1', repr=False) + _state: ServerState = attr.field(default=ServerState.CLOSED, repr=False) + + __metadata: dict[str, t.Any] | None = attr.field(default=None, repr=False) + __config: dict[str, t.Any] | None = attr.field(default=None, repr=False) + + @staticmethod + async def wait_until_server_ready(addr: str, timeout: float = 30, check_interval: int = 1, **client_args: t.Any) -> None: + addr = _address_converter(addr) + logger.debug('Wait for server @ %s to be ready', addr) + start = time.monotonic() + while time.monotonic() - start < timeout: + try: + async with httpx.AsyncClient(base_url=addr, **client_args) as sess: + status = (await sess.get('/readyz')).status_code + if status == 200: break + else: await asyncio.sleep(check_interval) + except (httpx.ConnectError, urllib.error.URLError, ConnectionError): + logger.debug('Server is not ready yet, retrying in %d seconds...', check_interval) + await asyncio.sleep(check_interval) + # Try once more and raise for exception + try: + async with httpx.AsyncClient(base_url=addr, **client_args) as sess: + status = (await sess.get('/readyz')).status_code + except 
httpx.HTTPStatusError as err: + logger.error('Failed to wait until server ready: %s', addr) + logger.error(err) + raise + + def __init__(self, address: str | None = None, timeout: int = 30, api_version: str = 'v1', **client_args: t.Any) -> None: + if address is None: + env = os.environ.get('OPENLLM_ENDPOINT') + if env is None: raise ValueError('address must be provided') + address = env + self.__attrs_init__(address, client_args, httpx.AsyncClient(base_url=address, timeout=timeout, **client_args), timeout, api_version) async def _metadata(self) -> dict[str, t.Any]: if self.__metadata is None: self.__metadata = (await self._inner.post(self._build_endpoint('metadata'))).json() @@ -99,34 +196,49 @@ async def _config(self) -> dict[str, t.Any]: self.__config = {**config, **generation_config} return self.__config - async def health(self): - return await self._inner.get('/readyz') + def _build_endpoint(self, endpoint: str) -> str: + return '/' + f'{self._api_version}/{endpoint}' + + @property + def is_ready(self) -> bool: + return self._state == ServerState.READY + + async def generate(self, prompt: str, **attrs: t.Any) -> Response: + return await self.query(prompt, **attrs) - def _build_endpoint(self, endpoint: str): - return '/' + f'{self.api_version}/{endpoint}' + async def health(self) -> None: + try: + await self.wait_until_server_ready(self.address, timeout=self._timeout, **self.client_args) + _object_setattr(self, '_state', ServerState.READY) + except Exception as e: + logger.error('Server is not healthy (Scroll up for traceback)\n%s', e) + _object_setattr(self, '_state', ServerState.CLOSED) async def query(self, prompt: str, **attrs: t.Any) -> Response: + timeout = attrs.pop('timeout', self._timeout) _meta, _config = await self._metadata(), await self._config() - client = httpx.AsyncClient(base_url=self.address, timeout=self.timeout, **self.client_args) + if _meta['prompt_template'] is not None: prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt) + + if not self.is_ready: + await self.health() + if not self.is_ready: raise RuntimeError('Server is not ready. Check server logs for more information.') + req = Request(prompt=_meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt), llm_config={**_config, **attrs}) - r = await client.post(self._build_endpoint('generate'), json=req.json(), **self.client_args) - payload = r.json() + async with httpx.AsyncClient(base_url=self.address, timeout=timeout, **self.client_args) as client: + r = await client.post(self._build_endpoint('generate'), json=req.json(), **self.client_args) if r.status_code != 200: raise ValueError("Failed to get generation from '/v1/generate'. 
Check server logs for more details.") - return Response(**payload) - - async def generate(self, prompt: str, **attrs: t.Any) -> Response: - return await self.query(prompt, **attrs) + return Response(**r.json()) async def generate_stream(self, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[StreamResponse, t.Any]: _meta, _config = await self._metadata(), await self._config() - client = httpx.AsyncClient(base_url=self.address, timeout=self.timeout, **self.client_args) - req = Request(prompt=_meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt), llm_config={**_config, **attrs}) - async with client.stream('POST', self._build_endpoint('generate_stream'), json=req.json(), **self.client_args) as r: - async for payload in r.aiter_bytes(): - # Skip line - payload = payload.decode('utf-8') - yield StreamResponse(text=payload) + if _meta['prompt_template'] is not None: prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt) + req = Request(prompt=prompt, llm_config={**_config, **attrs}) + async with httpx.AsyncClient(base_url=self.address, timeout=self._timeout, **self.client_args) as client: + async with client.stream('POST', self._build_endpoint('generate_stream'), json=req.json(), **self.client_args) as r: + async for payload in r.aiter_bytes(): + yield StreamResponse(text=payload.decode('utf-8')) # TODO: make it SSE correct for streaming + # Skip line # if payload == b"\n": continue # payload = payload.decode("utf-8") # if payload.startswith("data:"): diff --git a/openllm-core/src/openllm_core/__init__.py b/openllm-core/src/openllm_core/__init__.py index 16c6b9ba1..126622179 100644 --- a/openllm-core/src/openllm_core/__init__.py +++ b/openllm-core/src/openllm_core/__init__.py @@ -6,7 +6,6 @@ from ._configuration import GenerationConfig as GenerationConfig from ._configuration import LLMConfig as LLMConfig from ._configuration import SamplingParams as SamplingParams -from ._schema import EmbeddingsOutput as EmbeddingsOutput from ._schema import GenerationInput as GenerationInput from ._schema import GenerationOutput as GenerationOutput from ._schema import HfAgentInput as HfAgentInput diff --git a/openllm-core/src/openllm_core/_schema.py b/openllm-core/src/openllm_core/_schema.py index 418b36689..a0b332eb9 100644 --- a/openllm-core/src/openllm_core/_schema.py +++ b/openllm-core/src/openllm_core/_schema.py @@ -69,16 +69,10 @@ class MetadataOutput: model_name: str backend: str configuration: str - supports_embeddings: bool supports_hf_agent: bool prompt_template: str system_message: str -@attr.frozen(slots=True) -class EmbeddingsOutput: - embeddings: t.List[t.List[float]] - num_tokens: int - def unmarshal_vllm_outputs(request_output: vllm.RequestOutput) -> dict[str, t.Any]: return dict(request_id=request_output.request_id, prompt=request_output.prompt, diff --git a/openllm-core/src/openllm_core/_typing_compat.py b/openllm-core/src/openllm_core/_typing_compat.py index e20064fdb..2fc5ffab6 100644 --- a/openllm-core/src/openllm_core/_typing_compat.py +++ b/openllm-core/src/openllm_core/_typing_compat.py @@ -21,7 +21,6 @@ from bentoml._internal.runner.runner import RunnerMethod from bentoml._internal.runner.strategy import Strategy from openllm._llm import LLM - from openllm_core._schema import EmbeddingsOutput from .utils.lazy import VersionInfo @@ -92,7 +91,6 @@ class LLMRunnable(bentoml.Runnable, t.Generic[M, T]): SUPPORTED_RESOURCES = ('amd.com/gpu', 'nvidia.com/gpu', 'cpu') SUPPORTS_CPU_MULTI_THREADING = True 
__call__: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]] - embeddings: RunnableMethod[LLMRunnable[M, T], [list[str]], EmbeddingsOutput] generate: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]] generate_one: RunnableMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal['generated_text'], str]]] generate_iterator: RunnableMethod[LLMRunnable[M, T], [str], t.Iterator[t.Any]] @@ -108,12 +106,10 @@ class LLMRunner(bentoml.Runner, t.Generic[M, T]): llm: openllm.LLM[M, T] config: openllm.LLMConfig backend: LiteralBackend - supports_embeddings: bool supports_hf_agent: bool has_adapters: bool system_message: str | None prompt_template: str | None - embeddings: RunnerMethod[LLMRunnable[M, T], [list[str]], t.Sequence[EmbeddingsOutput]] generate: RunnerMethod[LLMRunnable[M, T], [str], list[t.Any]] generate_one: RunnerMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal['generated_text'], str]]] generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.Iterator[t.Any]] @@ -137,10 +133,6 @@ def __init__(self, def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: ... - @abc.abstractmethod - def embed(self, prompt: str | list[str]) -> EmbeddingsOutput: - ... - def run(self, prompt: str, **attrs: t.Any) -> t.Any: ... diff --git a/openllm-core/src/openllm_core/config/configuration_llama.py b/openllm-core/src/openllm_core/config/configuration_llama.py index cd13bf21b..0af9594ef 100644 --- a/openllm-core/src/openllm_core/config/configuration_llama.py +++ b/openllm-core/src/openllm_core/config/configuration_llama.py @@ -107,8 +107,8 @@ class SamplingParams: presence_penalty: float = 0.5 @property - def default_prompt_template(self, use_llama2_prompt: bool = True) -> str: - return DEFAULT_PROMPT_TEMPLATE('v2' if use_llama2_prompt else 'v1').to_string() + def default_prompt_template(self) -> str: + return DEFAULT_PROMPT_TEMPLATE('v2' if self.use_llama2_prompt else 'v1').to_string() @property def default_system_message(self) -> str: diff --git a/openllm-python/src/openllm/__init__.py b/openllm-python/src/openllm/__init__.py index 92b4c6821..abfca557e 100644 --- a/openllm-python/src/openllm/__init__.py +++ b/openllm-python/src/openllm/__init__.py @@ -15,7 +15,7 @@ from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig, SamplingParams as SamplingParams from openllm_core._strategies import CascadingResourceStrategy as CascadingResourceStrategy, get_resource as get_resource -from openllm_core._schema import EmbeddingsOutput as EmbeddingsOutput, GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput, unmarshal_vllm_outputs as unmarshal_vllm_outputs +from openllm_core._schema import GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput, unmarshal_vllm_outputs as unmarshal_vllm_outputs from openllm_core.config import AutoConfig as AutoConfig, CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, FlanT5Config as FlanT5Config, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig if openllm_core.utils.DEBUG: @@ -45,8 +45,7 @@ "serialisation": ["ggml", "transformers"], "cli._sdk": ["start", 
"start_grpc", "build", "import_model", "list_models"], "_quantisation": ["infer_quantisation_config"], - "_embeddings": ["GenericEmbeddingRunnable"], - "_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable", "EmbeddingsOutput"], + "_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"], "_generation": ["StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList", "prepare_logits_processor"], "models.auto": ["MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"], "models.chatglm": [], @@ -66,9 +65,8 @@ if _t.TYPE_CHECKING: from . import bundle as bundle, cli as cli, client as client, models as models, playground as playground, serialisation as serialisation, testing as testing from ._generation import LogitsProcessorList as LogitsProcessorList, StopOnTokens as StopOnTokens, StoppingCriteriaList as StoppingCriteriaList, StopSequenceCriteria as StopSequenceCriteria, prepare_logits_processor as prepare_logits_processor - from ._llm import LLM as LLM, EmbeddingsOutput as EmbeddingsOutput, LLMRunnable as LLMRunnable, LLMRunner as LLMRunner, Runner as Runner + from ._llm import LLM as LLM, LLMRunnable as LLMRunnable, LLMRunner as LLMRunner, Runner as Runner from ._quantisation import infer_quantisation_config as infer_quantisation_config - from ._embeddings import GenericEmbeddingRunnable as GenericEmbeddingRunnable from .cli._sdk import build as build, import_model as import_model, list_models as list_models, start as start, start_grpc as start_grpc from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES, MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES from .serialisation import ggml as ggml, transformers as transformers @@ -182,7 +180,7 @@ from .models.opt import TFOPT as TFOPT # NOTE: update this to sys.modules[__name__] once mypy_extensions can recognize __spec__ -__lazy = openllm_core.utils.LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects={"COMPILED": COMPILED, "__openllm_migration__": {"LLMEmbeddings": "EmbeddingsOutput"}}) +__lazy = openllm_core.utils.LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects={"COMPILED": COMPILED}) __all__ = __lazy.__all__ __dir__ = __lazy.__dir__ __getattr__ = __lazy.__getattr__ diff --git a/openllm-python/src/openllm/_embeddings.py b/openllm-python/src/openllm/_embeddings.py deleted file mode 100644 index 9f5f3257c..000000000 --- a/openllm-python/src/openllm/_embeddings.py +++ /dev/null @@ -1,79 +0,0 @@ -# See https://github.com/bentoml/sentence-embedding-bento for more information. 
-from __future__ import annotations -import typing as t - -import transformers - -from huggingface_hub import snapshot_download - -import bentoml -import openllm - -from bentoml._internal.frameworks.transformers import API_VERSION -from bentoml._internal.frameworks.transformers import MODULE_NAME -from bentoml._internal.models.model import ModelOptions -from bentoml._internal.models.model import ModelSignature - -if t.TYPE_CHECKING: - import torch - -_GENERIC_EMBEDDING_ID = 'sentence-transformers/all-MiniLM-L6-v2' -_BENTOMODEL_ID = 'sentence-transformers--all-MiniLM-L6-v2' - -def get_or_download(ids: str = _BENTOMODEL_ID) -> bentoml.Model: - try: - return bentoml.transformers.get(ids) - except bentoml.exceptions.NotFound: - model_signatures = { - k: ModelSignature(batchable=False) - for k in ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search', '__call__') - } - with bentoml.models.create(ids, - module=MODULE_NAME, - api_version=API_VERSION, - options=ModelOptions(), - context=openllm.utils.generate_context(framework_name='transformers'), - labels={ - 'runtime': 'pt', - 'framework': 'openllm' - }, - signatures=model_signatures) as bentomodel: - snapshot_download(_GENERIC_EMBEDDING_ID, - local_dir=bentomodel.path, - local_dir_use_symlinks=False, - ignore_patterns=['*.safetensors', '*.h5', '*.ot', '*.pdf', '*.md', '.gitattributes', 'LICENSE.txt']) - return bentomodel - -class GenericEmbeddingRunnable(bentoml.Runnable): - SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'cpu') - SUPPORTS_CPU_MULTI_THREADING = True - - def __init__(self) -> None: - self.device = 'cuda' if openllm.utils.device_count() > 0 else 'cpu' - self._bentomodel = get_or_download() - self.tokenizer = transformers.AutoTokenizer.from_pretrained(self._bentomodel.path) - self.model = transformers.AutoModel.from_pretrained(self._bentomodel.path) - self.model.to(self.device) - - @bentoml.Runnable.method(batchable=True, batch_dim=0) - def encode(self, sentences: list[str]) -> t.Sequence[openllm.EmbeddingsOutput]: - import torch - import torch.nn.functional as F - encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(self.device) - attention_mask = encoded_input['attention_mask'] - # Compute token embeddings - with torch.no_grad(): - model_output = self.model(**encoded_input) - # Perform pooling and normalize - sentence_embeddings = F.normalize(self.mean_pooling(model_output, attention_mask), p=2, dim=1) - return [openllm.EmbeddingsOutput(embeddings=sentence_embeddings.cpu().numpy(), num_tokens=int(torch.sum(attention_mask).item()))] - - @staticmethod - def mean_pooling(model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - import torch - # Mean Pooling - Take attention mask into account for correct averaging - token_embeddings = model_output[0] # First element of model_output contains all token embeddings - input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() - return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) - -__all__ = ['GenericEmbeddingRunnable'] diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 3f303c913..dff0eb6d8 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -21,7 +21,6 @@ from bentoml._internal.models.model import ModelSignature from openllm_core._configuration import FineTuneConfig 
from openllm_core._configuration import LLMConfig -from openllm_core._schema import EmbeddingsOutput from openllm_core._typing_compat import AdaptersMapping from openllm_core._typing_compat import AdaptersTuple from openllm_core._typing_compat import AdapterType @@ -165,16 +164,6 @@ def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: ''' raise NotImplementedError - def embeddings(self, prompts: list[str]) -> EmbeddingsOutput: - '''The implementation for generating text embeddings from given prompt. - - It takes the prompt and output the embeddings for this given LLM. - - Returns: - The embeddings for the given prompt. - ''' - raise NotImplementedError - class LLMSerialisation(abc.ABC, t.Generic[M, T]): def import_model(self, *args: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model: '''Import both model and tokenizer weights into as a BentoML models. @@ -261,8 +250,6 @@ def import_kwargs(self) -> tuple[DictStrAny, DictStrAny] | None: __llm_adapter_map__: t.Optional[ResolvedAdaptersMapping] '''A reference to the the cached LoRA adapter mapping.''' - __llm_supports_embeddings__: bool - '''A boolean to determine whether models does implement ``LLM.embeddings``.''' __llm_supports_generate__: bool '''A boolean to determine whether models does implement ``LLM.generate``.''' __llm_supports_generate_one__: bool @@ -338,10 +325,6 @@ def __getitem__(self, item: t.Literal['tokenizer']) -> T | None: def __getitem__(self, item: t.Literal['adapter_map']) -> ResolvedAdaptersMapping | None: ... - @overload - def __getitem__(self, item: t.Literal['supports_embeddings']) -> bool: - ... - @overload def __getitem__(self, item: t.Literal['supports_generate']) -> bool: ... @@ -876,18 +859,16 @@ def to_runner(self, raise RuntimeError(f'Failed to locate {self._bentomodel}:{err}') from None generate_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=False)) - embeddings_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=True, batch_dim=0)) generate_iterator_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=False)) # NOTE: returning the two langchain API's to the runner - return llm_runner_class(self)(llm_runnable_class(self, embeddings_sig, generate_sig, generate_iterator_sig), + return llm_runner_class(self)(llm_runnable_class(self, generate_sig, generate_iterator_sig), name=self.runner_name, embedded=False, models=models, max_batch_size=max_batch_size, max_latency_ms=max_latency_ms, method_configs=bentoml_cattr.unstructure({ - 'embeddings': embeddings_sig, '__call__': generate_sig, 'generate': generate_sig, 'generate_one': generate_sig, @@ -970,14 +951,14 @@ def generate_iterator(self, past_key_values = out = token = None finish_reason = None for i in range(config['max_new_tokens']): - torch.cuda.synchronize() + if torch.cuda.is_available(): torch.cuda.synchronize() if i == 0: # prefill out = self.model(torch.as_tensor([input_ids], device=self.device), use_cache=True) else: # decoding out = self.model(torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values) logits = out.logits past_key_values = out.past_key_values - torch.cuda.synchronize() + if torch.cuda.is_available(): torch.cuda.synchronize() if logits_processor: if config['repetition_penalty'] > 1.0: @@ -1139,7 +1120,7 @@ class SetAdapterOutput(t.TypedDict): success: bool message: str -def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate_sig: ModelSignature, generate_iterator_sig: ModelSignature) -> type[LLMRunnable[M, 
T]]: +def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_iterator_sig: ModelSignature) -> type[LLMRunnable[M, T]]: class _Runnable(bentoml.Runnable): SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu') SUPPORTS_CPU_MULTI_THREADING = True @@ -1159,10 +1140,6 @@ def set_adapter(__self: _Runnable, adapter_name: str) -> None: if adapter_name != 'default': self.model.set_adapter(adapter_name) logger.info('Successfully apply LoRA layer %s', adapter_name) - @bentoml.Runnable.method(**method_signature(embeddings_sig)) # type: ignore - def embeddings(__self: _Runnable, prompt: str | list[str]) -> t.Sequence[EmbeddingsOutput]: - return [self.embeddings([prompt] if isinstance(prompt, str) else prompt)] - @bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore def __call__(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]: prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs) @@ -1303,18 +1280,6 @@ def _wrapped_generate_run(__self: LLMRunner[M, T], prompt: str, **kwargs: t.Any) prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **kwargs) return self.postprocess_generate(prompt, __self.generate.run(prompt, **generate_kwargs), **postprocess_kwargs) - def _wrapped_embeddings_run(__self: LLMRunner[M, T], prompt: str | list[str]) -> EmbeddingsOutput: - '''``llm.embed`` is a light wrapper around runner.embeedings.run(). - - Usage: - - ```python - runner = openllm.Runner('llama', backend='pt') - runner.embed("What is the meaning of life?") - ``` - ''' - return __self.embeddings.run([prompt] if isinstance(prompt, str) else prompt) - def _wrapped_repr_keys(_: LLMRunner[M, T]) -> set[str]: return {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'} @@ -1325,6 +1290,14 @@ def _wrapped_repr_args(__self: LLMRunner[M, T]) -> ReprArgs: yield 'backend', self.__llm_backend__ yield 'llm_tag', self.tag + if self._prompt_template: prompt_template = self._prompt_template.to_string() + elif hasattr(self.config, 'default_prompt_template'): prompt_template = self.config.default_prompt_template + else: prompt_template = None + + if self._system_message: system_message = self._system_message + elif hasattr(self.config, 'default_system_message'): system_message = self.config.default_system_message + else: system_message = None + return types.new_class(self.__class__.__name__ + 'Runner', (bentoml.Runner,), exec_body=lambda ns: ns.update({ 'llm_type': self.llm_type, @@ -1336,17 +1309,15 @@ def _wrapped_repr_args(__self: LLMRunner[M, T]) -> ReprArgs: 'peft_adapters': property(fget=available_adapters), 'download_model': self.save_pretrained, '__call__': _wrapped_generate_run, - 'embed': _wrapped_embeddings_run, '__module__': self.__module__, '__doc__': self.config['env'].start_docstring, '__repr__': ReprMixin.__repr__, '__repr_keys__': property(_wrapped_repr_keys), '__repr_args__': _wrapped_repr_args, - 'supports_embeddings': self['supports_embeddings'], 'supports_hf_agent': self['supports_generate_one'], 'has_adapters': self._adapters_mapping is not None, - 'prompt_template': self._prompt_template.to_string() if self._prompt_template else self.config.default_prompt_template, - 'system_message': self._system_message if self._system_message else self.config.default_system_message, + 'prompt_template': prompt_template, + 'system_message': system_message, })) -__all__ = ['LLMRunner', 'LLMRunnable', 'Runner', 'LLM', 'llm_runner_class', 'llm_runnable_class', 'EmbeddingsOutput'] +__all__ = ['LLMRunner', 'LLMRunnable', 
'Runner', 'LLM', 'llm_runner_class', 'llm_runnable_class'] diff --git a/openllm-python/src/openllm/_service.py b/openllm-python/src/openllm/_service.py index 366b323d9..08066524e 100644 --- a/openllm-python/src/openllm/_service.py +++ b/openllm-python/src/openllm/_service.py @@ -19,11 +19,6 @@ from starlette.requests import Request from starlette.responses import Response - from bentoml._internal.runner.runner import AbstractRunner - from bentoml._internal.runner.runner import RunnerMethod - from openllm_core._typing_compat import TypeAlias - _EmbeddingMethod: TypeAlias = RunnerMethod[t.Union[bentoml.Runnable, openllm.LLMRunnable[t.Any, t.Any]], [t.List[str]], t.Sequence[openllm.EmbeddingsOutput]] - # The following warnings from bitsandbytes, and probably not that important for users to see warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization') warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization') @@ -34,14 +29,7 @@ adapter_map = svars.adapter_map llm_config = openllm.AutoConfig.for_model(model) runner = openllm.Runner(model, llm_config=llm_config, model_id=model_id, ensure_available=False, adapter_map=orjson.loads(adapter_map)) -generic_embedding_runner = bentoml.Runner(openllm.GenericEmbeddingRunnable, # XXX: remove arg-type once bentoml.Runner is correct set with type - name='llm-generic-embedding', - scheduling_strategy=openllm_core.CascadingResourceStrategy, - max_batch_size=32, - max_latency_ms=300) -runners: list[AbstractRunner] = [runner] -if not runner.supports_embeddings: runners.append(generic_embedding_runner) -svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=runners) +svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[runner]) _JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': None}) @@ -193,7 +181,6 @@ def models_v1() -> t.List[dict[str, t.Any]]: 'model_name': llm_config['model_name'], 'backend': runner.backend, 'configuration': llm_config.model_dump(flatten=True), - 'supports_embeddings': runner.supports_embeddings, 'supports_hf_agent': runner.supports_hf_agent, 'prompt_template': runner.prompt_template, 'system_message': runner.system_message, @@ -204,27 +191,11 @@ def metadata_v1(_: str) -> openllm.MetadataOutput: backend=llm_config['env']['backend_value'], model_id=runner.llm.model_id, configuration=llm_config.model_dump_json().decode(), - supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent, prompt_template=runner.prompt_template, system_message=runner.system_message, ) -@svc.api(route='/v1/embeddings', - input=bentoml.io.JSON.from_sample(['Hey Jude, welcome to the jungle!', 'What is the meaning of life?']), - output=bentoml.io.JSON.from_sample({ - 'embeddings': [ - 0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, - 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, - 0.01876218430697918, 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, - 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076 - ], - 'num_tokens': 20 
- })) -async def embeddings_v1(phrases: list[str]) -> list[openllm.EmbeddingsOutput]: - embed_call: _EmbeddingMethod = runner.embeddings if runner.supports_embeddings else generic_embedding_runner.encode # type: ignore[type-arg,assignment,valid-type] - return await embed_call.async_run(phrases) - if runner.supports_hf_agent: async def hf_agent(request: Request) -> Response: diff --git a/openllm-python/src/openllm/cli/entrypoint.py b/openllm-python/src/openllm/cli/entrypoint.py index 00b47bdd4..01ff796f6 100644 --- a/openllm-python/src/openllm/cli/entrypoint.py +++ b/openllm-python/src/openllm/cli/entrypoint.py @@ -21,7 +21,6 @@ """ from __future__ import annotations import functools -import http.client import inspect import itertools import logging @@ -112,7 +111,8 @@ from bentoml._internal.bento import BentoStore from bentoml._internal.container import DefaultBuilder - from openllm_core._schema import EmbeddingsOutput + from openllm_client._schemas import Response + from openllm_client._schemas import StreamResponse from openllm_core._typing_compat import LiteralContainerRegistry from openllm_core._typing_compat import LiteralContainerVersionStrategy else: @@ -130,17 +130,20 @@ ''' ServeCommand = t.Literal['serve', 'serve-grpc'] + @attr.define class GlobalOptions: cloud_context: str | None = attr.field(default=None) def with_options(self, **attrs: t.Any) -> Self: return attr.evolve(self, **attrs) + GrpType = t.TypeVar('GrpType', bound=click.Group) _object_setattr = object.__setattr__ _EXT_FOLDER = os.path.abspath(os.path.join(os.path.dirname(__file__), 'extension')) + class Extensions(click.MultiCommand): def list_commands(self, ctx: click.Context) -> list[str]: return sorted([filename[:-3] for filename in os.listdir(_EXT_FOLDER) if filename.endswith('.py') and not filename.startswith('__')]) @@ -151,6 +154,7 @@ def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None except ImportError: return None return mod.cli + class OpenLLMCommandGroup(BentoMLCommandGroup): NUMBER_OF_COMMON_PARAMS = 5 # parameters in common_params + 1 faked group option header @@ -284,10 +288,12 @@ def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> if rows: with formatter.section(_('Extensions')): formatter.write_dl(rows) + @click.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='openllm') -@click.version_option( - None, '--version', '-v', message=f"%(prog)s, %(version)s (compiled: {'yes' if openllm.COMPILED else 'no'})\nPython ({platform.python_implementation()}) {platform.python_version()}" -) +@click.version_option(None, + '--version', + '-v', + message=f"%(prog)s, %(version)s (compiled: {'yes' if openllm.COMPILED else 'no'})\nPython ({platform.python_implementation()}) {platform.python_version()}") def cli() -> None: '''\b ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗ @@ -301,6 +307,7 @@ def cli() -> None: An open platform for operating large language models in production. Fine-tune, serve, deploy, and monitor any LLMs with ease. ''' + @cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='start', aliases=['start-http']) def start_command() -> None: '''Start any LLM as a REST server. @@ -310,6 +317,7 @@ def start_command() -> None: $ openllm -- ... ``` ''' + @cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='start-grpc') def start_grpc_command() -> None: '''Start any LLM as a gRPC server. 
@@ -319,6 +327,7 @@ def start_grpc_command() -> None: $ openllm start-grpc -- ... ``` ''' + _start_mapping = { 'start': { key: start_command_factory(start_command, key, _context_settings=termui.CONTEXT_SETTINGS) for key in CONFIG_MAPPING @@ -327,6 +336,7 @@ def start_grpc_command() -> None: key: start_command_factory(start_grpc_command, key, _context_settings=termui.CONTEXT_SETTINGS, _serve_grpc=True) for key in CONFIG_MAPPING } } + @cli.command(name='import', aliases=['download']) @model_name_argument @click.argument('model_id', type=click.STRING, default=None, metavar='Optional[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=False) @@ -337,17 +347,9 @@ def start_grpc_command() -> None: @machine_option @backend_option @serialisation_option -def import_command( - model_name: str, - model_id: str | None, - converter: str | None, - model_version: str | None, - output: LiteralOutput, - machine: bool, - backend: LiteralBackend, - quantize: LiteralQuantise | None, - serialisation: LiteralSerialisation | None, -) -> bentoml.Model: +def import_command(model_name: str, model_id: str | None, converter: str | None, model_version: str | None, output: LiteralOutput, machine: bool, backend: LiteralBackend, + quantize: LiteralQuantise | None, serialisation: LiteralSerialisation | None, + ) -> bentoml.Model: """Setup LLM interactively. It accepts two positional arguments: `model_name` and `model_id`. The first name determine @@ -402,7 +404,13 @@ def import_command( _serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation']) env = EnvVarMixin(model_name, backend=llm_config.default_backend(), model_id=model_id, quantize=quantize) backend = first_not_none(backend, default=env['backend_value']) - llm = infer_auto_class(backend).for_model(model_name, model_id=env['model_id_value'], llm_config=llm_config, model_version=model_version, ensure_available=False, quantize=env['quantize_value'], serialisation=_serialisation) + llm = infer_auto_class(backend).for_model(model_name, + model_id=env['model_id_value'], + llm_config=llm_config, + model_version=model_version, + ensure_available=False, + quantize=env['quantize_value'], + serialisation=_serialisation) _previously_saved = False try: _ref = openllm.serialisation.get(llm) @@ -434,66 +442,40 @@ def import_command( @workers_per_resource_option(factory=click, build=True) @cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Optimisation options') @quantize_option(factory=cog.optgroup, build=True) -@click.option( - '--enable-features', - multiple=True, - nargs=1, - metavar='FEATURE[,FEATURE]', - help='Enable additional features for building this LLM Bento. Available: {}'.format(', '.join(OPTIONAL_DEPENDENCIES)) -) -@click.option( - '--adapter-id', - default=None, - multiple=True, - metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]', - help="Optional adapters id to be included within the Bento. Note that if you are using relative path, '--build-ctx' must be passed." -) +@click.option('--enable-features', + multiple=True, + nargs=1, + metavar='FEATURE[,FEATURE]', + help='Enable additional features for building this LLM Bento. Available: {}'.format(', '.join(OPTIONAL_DEPENDENCIES))) +@click.option('--adapter-id', + default=None, + multiple=True, + metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]', + help="Optional adapters id to be included within the Bento. Note that if you are using relative path, '--build-ctx' must be passed.") @click.option('--build-ctx', help='Build context. 
This is required if --adapter-id uses relative path', default=None) @model_version_option @click.option('--dockerfile-template', default=None, type=click.File(), help='Optional custom dockerfile template to be used with this BentoLLM.') @serialisation_option @container_registry_option -@click.option( - '--container-version-strategy', type=click.Choice(['release', 'latest', 'nightly']), default='release', help="Default container version strategy for the image from '--container-registry'" -) +@click.option('--container-version-strategy', + type=click.Choice(['release', 'latest', 'nightly']), + default='release', + help="Default container version strategy for the image from '--container-registry'") @cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Utilities options') -@cog.optgroup.option( - '--containerize', - default=False, - is_flag=True, - type=click.BOOL, - help="Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'." -) +@cog.optgroup.option('--containerize', + default=False, + is_flag=True, + type=click.BOOL, + help="Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'.") @cog.optgroup.option('--push', default=False, is_flag=True, type=click.BOOL, help="Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first.") @click.option('--force-push', default=False, is_flag=True, type=click.BOOL, help='Whether to force push.') @click.pass_context -def build_command( - ctx: click.Context, - /, - model_name: str, - model_id: str | None, - bento_version: str | None, - overwrite: bool, - output: LiteralOutput, - quantize: LiteralQuantise | None, - enable_features: tuple[str, ...] | None, - workers_per_resource: float | None, - adapter_id: tuple[str, ...], - build_ctx: str | None, - backend: LiteralBackend, - system_message: str | None, - prompt_template_file: t.IO[t.Any] | None, - machine: bool, - model_version: str | None, - dockerfile_template: t.TextIO | None, - containerize: bool, - push: bool, - serialisation: LiteralSerialisation | None, - container_registry: LiteralContainerRegistry, - container_version_strategy: LiteralContainerVersionStrategy, - force_push: bool, - **attrs: t.Any, -) -> bentoml.Bento: +def build_command(ctx: click.Context, /, model_name: str, model_id: str | None, bento_version: str | None, overwrite: bool, output: LiteralOutput, quantize: LiteralQuantise | None, + enable_features: tuple[str, ...] | None, workers_per_resource: float | None, adapter_id: tuple[str, ...], build_ctx: str | None, backend: LiteralBackend, + system_message: str | None, prompt_template_file: t.IO[t.Any] | None, machine: bool, model_version: str | None, dockerfile_template: t.TextIO | None, containerize: bool, + push: bool, serialisation: LiteralSerialisation | None, container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy, + force_push: bool, **attrs: t.Any, + ) -> bentoml.Bento: '''Package a given models into a Bento. 
\b @@ -530,7 +512,16 @@ def build_command( if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template - llm = infer_auto_class(env['backend_value']).for_model(model_name, model_id=env['model_id_value'], prompt_template=prompt_template, system_message=system_message, llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs) + llm = infer_auto_class(env['backend_value']).for_model(model_name, + model_id=env['model_id_value'], + prompt_template=prompt_template, + system_message=system_message, + llm_config=llm_config, + ensure_available=True, + model_version=model_version, + quantize=env['quantize_value'], + serialisation=_serialisation, + **attrs) labels = dict(llm.identifying_params) labels.update({'_type': llm.llm_type, '_framework': env['backend_value']}) @@ -575,18 +566,16 @@ def build_command( raise bentoml.exceptions.NotFound(f'Rebuilding existing Bento {bento_tag}') from None _previously_built = True except bentoml.exceptions.NotFound: - bento = bundle.create_bento( - bento_tag, - llm_fs, - llm, - workers_per_resource=workers_per_resource, - adapter_map=adapter_map, - quantize=quantize, - extra_dependencies=enable_features, - dockerfile_template=dockerfile_template_path, - container_registry=container_registry, - container_version_strategy=container_version_strategy - ) + bento = bundle.create_bento(bento_tag, + llm_fs, + llm, + workers_per_resource=workers_per_resource, + adapter_map=adapter_map, + quantize=quantize, + extra_dependencies=enable_features, + dockerfile_template=dockerfile_template_path, + container_registry=container_registry, + container_version_strategy=container_version_strategy) except Exception as err: raise err from None @@ -596,12 +585,11 @@ def build_command( termui.echo('\n' + OPENLLM_FIGLET, fg='white') if not _previously_built: termui.echo(f'Successfully built {bento}.', fg='green') elif not overwrite: termui.echo(f"'{model_name}' already has a Bento built [{bento}]. 
To overwrite it pass '--overwrite'.", fg='yellow') - termui.echo( - '📖 Next steps:\n\n' + f"* Push to BentoCloud with 'bentoml push':\n\t$ bentoml push {bento.tag}\n\n" + - f"* Containerize your Bento with 'bentoml containerize':\n\t$ bentoml containerize {bento.tag} --opt progress=plain\n\n" + - "\tTip: To enable additional BentoML features for 'containerize', use '--enable-features=FEATURE[,FEATURE]' [see 'bentoml containerize -h' for more advanced usage]\n", - fg='blue', - ) + termui.echo('📖 Next steps:\n\n' + f"* Push to BentoCloud with 'bentoml push':\n\t$ bentoml push {bento.tag}\n\n" + + f"* Containerize your Bento with 'bentoml containerize':\n\t$ bentoml containerize {bento.tag} --opt progress=plain\n\n" + + "\tTip: To enable additional BentoML features for 'containerize', use '--enable-features=FEATURE[,FEATURE]' [see 'bentoml containerize -h' for more advanced usage]\n", + fg='blue', + ) elif output == 'json': termui.echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode()) else: @@ -688,7 +676,7 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo data: list[str | tuple[str, str, list[str], str, tuple[LiteralBackend, ...]]] = [] for m, v in json_data.items(): data.extend([(m, v['architecture'], v['model_id'], v['installation'], v['backend'])]) - column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4)] + column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4)] if len(data) == 0 and len(failed_initialized) > 0: termui.echo('Exception found while parsing models:\n', fg='yellow') @@ -716,14 +704,17 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo if show_available: json_data['local'] = local_models termui.echo(orjson.dumps(json_data, option=orjson.OPT_INDENT_2,).decode(), fg='white') ctx.exit(0) + @cli.command() @model_name_argument(required=False) @click.option('-y', '--yes', '--assume-yes', is_flag=True, help='Skip confirmation when deleting a specific model') @click.option('--include-bentos/--no-include-bentos', is_flag=True, default=False, help='Whether to also include pruning bentos.') @inject -def prune_command( - model_name: str | None, yes: bool, include_bentos: bool, model_store: ModelStore = Provide[BentoMLContainer.model_store], bento_store: BentoStore = Provide[BentoMLContainer.bento_store] -) -> None: +def prune_command(model_name: str | None, + yes: bool, + include_bentos: bool, + model_store: ModelStore = Provide[BentoMLContainer.model_store], + bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> None: '''Remove all saved models, (and optionally bentos) built with OpenLLM locally. 
\b @@ -744,6 +735,7 @@ def prune_command( if delete_confirmed: store.delete(store_item.tag) termui.echo(f"{store_item} deleted from {'model' if isinstance(store, ModelStore) else 'bento'} store.", fg='yellow') + def parsing_instruction_callback(ctx: click.Context, param: click.Parameter, value: list[str] | str | None) -> tuple[str, bool | str] | list[str] | str | None: if value is None: return value @@ -762,6 +754,7 @@ def parsing_instruction_callback(ctx: click.Context, param: click.Parameter, val return key, values[0] else: raise click.BadParameter(f'Invalid option format: {value}') + def shared_client_options(f: _AnyCallable | None = None, output_value: t.Literal['json', 'porcelain', 'pretty'] = 'pretty') -> t.Callable[[FC], FC]: options = [ click.option('--endpoint', type=click.STRING, help='OpenLLM Server endpoint, i.e: http://localhost:3000', envvar='OPENLLM_ENDPOINT', default='http://localhost:3000', @@ -770,20 +763,19 @@ def shared_client_options(f: _AnyCallable | None = None, output_value: t.Literal output_option(default_value=output_value), ] return compose(*options)(f) if f is not None else compose(*options) + @cli.command() @click.argument('task', type=click.STRING, metavar='TASK') @shared_client_options @click.option('--agent', type=click.Choice(['hf']), default='hf', help='Whether to interact with Agents from given Server endpoint.', show_default=True) @click.option('--remote', is_flag=True, default=False, help='Whether or not to use remote tools (inference endpoints) instead of local ones.', show_default=True) -@click.option( - '--opt', - help="Define prompt options. " - "(format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)", - required=False, - multiple=True, - callback=opt_callback, - metavar='ARG=VALUE[,ARG=VALUE]' -) +@click.option('--opt', + help="Define prompt options. " + "(format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)", + required=False, + multiple=True, + callback=opt_callback, + metavar='ARG=VALUE[,ARG=VALUE]') def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output: LiteralOutput, remote: bool, task: str, _memoized: DictStrAny, **attrs: t.Any) -> str: '''Instruct agents interactively for given tasks, from a terminal. 
@@ -795,66 +787,37 @@ def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output: ``` ''' raise click.ClickException("'instruct' is currently disabled") - client = openllm.client.HTTPClient(endpoint, timeout=timeout) + # client = openllm.client.HTTPClient(endpoint, timeout=timeout) + # + # try: + # client.call('metadata') + # except http.client.BadStatusLine: + # raise click.ClickException(f'{endpoint} is neither a HTTP server nor reachable.') from None + # if agent == 'hf': + # _memoized = {k: v[0] for k, v in _memoized.items() if v} + # client._hf_agent.set_stream(logger.info) + # if output != 'porcelain': termui.echo(f"Sending the following prompt ('{task}') with the following vars: {_memoized}", fg='magenta') + # result = client.ask_agent(task, agent_type=agent, return_code=False, remote=remote, **_memoized) + # if output == 'json': termui.echo(orjson.dumps(result, option=orjson.OPT_INDENT_2).decode(), fg='white') + # else: termui.echo(result, fg='white') + # return result + # else: + # raise click.BadOptionUsage('agent', f'Unknown agent type {agent}') - try: - client.call('metadata') - except http.client.BadStatusLine: - raise click.ClickException(f'{endpoint} is neither a HTTP server nor reachable.') from None - if agent == 'hf': - _memoized = {k: v[0] for k, v in _memoized.items() if v} - client._hf_agent.set_stream(logger.info) - if output != 'porcelain': termui.echo(f"Sending the following prompt ('{task}') with the following vars: {_memoized}", fg='magenta') - result = client.ask_agent(task, agent_type=agent, return_code=False, remote=remote, **_memoized) - if output == 'json': termui.echo(orjson.dumps(result, option=orjson.OPT_INDENT_2).decode(), fg='white') - else: termui.echo(result, fg='white') - return result - else: - raise click.BadOptionUsage('agent', f'Unknown agent type {agent}') -@cli.command() -@shared_client_options(output_value='json') -@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True) -@click.argument('text', type=click.STRING, nargs=-1) -@machine_option -@click.pass_context -def embed_command( - ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, machine: bool -) -> EmbeddingsOutput | None: - '''Get embeddings interactively, from a terminal. - - \b - ```bash - $ openllm embed --endpoint http://12.323.2.1:3000 "What is the meaning of life?" "How many stars are there in the sky?" 
- ``` - ''' - client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == 'http' else openllm.client.GrpcClient(endpoint, timeout=timeout) - try: - gen_embed = client.embed(text) - except ValueError: - raise click.ClickException(f'Endpoint {endpoint} does not support embeddings.') from None - if machine: return gen_embed - elif output == 'pretty': - termui.echo('Generated embeddings: ', fg='magenta', nl=False) - termui.echo(gen_embed.embeddings, fg='white') - termui.echo('\nNumber of tokens: ', fg='magenta', nl=False) - termui.echo(gen_embed.num_tokens, fg='white') - elif output == 'json': - termui.echo(orjson.dumps(bentoml_cattr.unstructure(gen_embed), option=orjson.OPT_INDENT_2).decode(), fg='white') - else: - termui.echo(gen_embed.embeddings, fg='white') - ctx.exit(0) @cli.command() @shared_client_options(output_value='porcelain') @click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True) @click.option('--stream/--no-stream', type=click.BOOL, is_flag=True, default=True, help='Whether to stream the response.') @click.argument('prompt', type=click.STRING) -@click.option( - '--sampling-params', help='Define query options. (format: ``--opt temperature=0.8 --opt=top_k:12)', required=False, multiple=True, callback=opt_callback, metavar='ARG=VALUE[,ARG=VALUE]' -) +@click.option('--sampling-params', + help='Define query options. (format: ``--opt temperature=0.8 --opt=top_k:12)', + required=False, + multiple=True, + callback=opt_callback, + metavar='ARG=VALUE[,ARG=VALUE]') @click.pass_context -def query_command( - ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, stream: bool, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, _memoized: DictStrAny, **attrs: t.Any -) -> None: +def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, stream: bool, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, _memoized: DictStrAny, + **attrs: t.Any) -> None: '''Ask a LLM interactively, from a terminal. 
\b @@ -870,24 +833,32 @@ def query_command( if output != 'porcelain': termui.echo('==Input==\n', fg='white') termui.echo(f'{prompt}', fg=input_fg) - fn = client.generate_stream if stream else client.generate - res = fn(prompt, **{**client._config(), **_memoized}) - if output == 'pretty': - termui.echo('\n\n==Responses==\n', fg='white') - if stream: - for it in res: termui.echo(it.text, fg=generated_fg, nl=False) - else: termui.echo(res.responses[0], fg=generated_fg) - elif output == 'json': - if stream: - for it in res: termui.echo(orjson.dumps(bentoml_cattr.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white') - else: termui.echo(orjson.dumps(bentoml_cattr.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white') - else: # noqa: PLR5501 - if stream: - for it in res: termui.echo(it.text, fg=generated_fg, nl=False) - else: termui.echo(res.responses, fg='white') + + if stream: + stream_res: t.Iterator[StreamResponse] = client.generate_stream(prompt, **{**client._config(), **_memoized}) + if output == 'pretty': + termui.echo('\n\n==Responses==\n', fg='white') + for it in stream_res: + termui.echo(it.text, fg=generated_fg, nl=False) + elif output == 'json': + for it in stream_res: + termui.echo(orjson.dumps(bentoml_cattr.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white') + else: + for it in stream_res: + termui.echo(it.text, fg=generated_fg, nl=False) + else: + res: Response = client.generate(prompt, **{**client._config(), **_memoized}) + if output == 'pretty': + termui.echo('\n\n==Responses==\n', fg='white') + termui.echo(res.responses[0], fg=generated_fg) + elif output == 'json': + termui.echo(orjson.dumps(bentoml_cattr.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white') + else: + termui.echo(res.responses, fg='white') ctx.exit(0) @cli.group(cls=Extensions, hidden=True, name='extension') def extension_command() -> None: '''Extension for OpenLLM CLI.''' + if __name__ == '__main__': cli() diff --git a/openllm-python/src/openllm/client.py b/openllm-python/src/openllm/client.py index 6a90bd608..369f44d24 100644 --- a/openllm-python/src/openllm/client.py +++ b/openllm-python/src/openllm/client.py @@ -4,11 +4,6 @@ client = openllm.client.HTTPClient("http://localhost:8080") client.query("What is the difference between gather and scatter?") ``` - -If the server has embedding supports, use it via `client.embed`: -```python -client.embed("What is the difference between gather and scatter?") -``` ''' from __future__ import annotations import typing as t diff --git a/openllm-python/src/openllm/models/chatglm/modeling_chatglm.py b/openllm-python/src/openllm/models/chatglm/modeling_chatglm.py index 0bed146cc..e80f726a2 100644 --- a/openllm-python/src/openllm/models/chatglm/modeling_chatglm.py +++ b/openllm-python/src/openllm/models/chatglm/modeling_chatglm.py @@ -15,17 +15,3 @@ def generate(self, prompt: str, **attrs: t.Any) -> tuple[str, list[tuple[str, st # Only use half precision if the model is not yet quantized if self.config.use_half_precision: self.model.half() return self.model.chat(self.tokenizer, prompt, generation_config=self.config.model_construct_env(**attrs).to_generation_config()) - - def embeddings(self, prompts: list[str]) -> openllm.EmbeddingsOutput: - import torch - import torch.nn.functional as F - embeddings: list[list[float]] = [] - num_tokens = 0 - for prompt in prompts: - input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device) - with torch.inference_mode(): - outputs = self.model(input_ids, 
output_hidden_states=True) - data = F.normalize(torch.mean(outputs.hidden_states[-1].transpose(0, 1), dim=0), p=2, dim=0) - embeddings.append(data.tolist()) - num_tokens += len(input_ids[0]) - return openllm.EmbeddingsOutput(embeddings=embeddings, num_tokens=num_tokens) diff --git a/openllm-python/src/openllm/models/flan_t5/modeling_flan_t5.py b/openllm-python/src/openllm/models/flan_t5/modeling_flan_t5.py index 601e0749e..45c2be95a 100644 --- a/openllm-python/src/openllm/models/flan_t5/modeling_flan_t5.py +++ b/openllm-python/src/openllm/models/flan_t5/modeling_flan_t5.py @@ -15,17 +15,3 @@ def generate(self, prompt: str, **attrs: t.Any) -> list[str]: do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True) - - def embeddings(self, prompts: list[str]) -> openllm.EmbeddingsOutput: - import torch - import torch.nn.functional as F - embeddings: list[list[float]] = [] - num_tokens = 0 - for prompt in prompts: - input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device) - with torch.inference_mode(): - outputs = self.model(input_ids, decoder_input_ids=input_ids) - data = F.normalize(torch.mean(outputs.encoder_last_hidden_state[0], dim=0), p=2, dim=0) - embeddings.append(data.tolist()) - num_tokens += len(input_ids[0]) - return openllm.EmbeddingsOutput(embeddings=embeddings, num_tokens=num_tokens) diff --git a/openllm-python/src/openllm/models/llama/modeling_llama.py b/openllm-python/src/openllm/models/llama/modeling_llama.py index 6c37d1fa8..1422ee916 100644 --- a/openllm-python/src/openllm/models/llama/modeling_llama.py +++ b/openllm-python/src/openllm/models/llama/modeling_llama.py @@ -12,15 +12,3 @@ class Llama(openllm.LLM['transformers.LlamaForCausalLM', 'transformers.LlamaToke def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: import torch return {'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {} - - def embeddings(self, prompts: list[str]) -> openllm.EmbeddingsOutput: - import torch - import torch.nn.functional as F - encoding = self.tokenizer(prompts, padding=True, return_tensors='pt').to(self.device) - input_ids, attention_mask = encoding['input_ids'], encoding['attention_mask'] - with torch.inference_mode(): - data = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1] - mask = attention_mask.unsqueeze(-1).expand(data.size()).float() - masked_embeddings = data * mask - sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1) - return openllm.EmbeddingsOutput(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=int(torch.sum(attention_mask).item())) diff --git a/tools/run-release-action b/tools/run-release-action index bd0f97b95..5744bda82 100755 --- a/tools/run-release-action +++ b/tools/run-release-action @@ -102,10 +102,6 @@ sleep 5 echo "Building OpenLLM container for ${RELEASE_TAG}..." gh workflow run build.yml -R bentoml/openllm -r "${RELEASE_TAG}" -sleep 5 -echo "Building OpenLLM embedding container for ${RELEASE_TAG}..." -gh workflow run build-embedding.yml -R bentoml/openllm -r "${RELEASE_TAG}" - sleep 5 echo "Building Clojure UI (community-maintained) for ${RELEASE_TAG}..." gh workflow run clojure-frontend.yml -R bentoml/openllm -r "${RELEASE_TAG}"
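
A few illustrative notes on the hunks above, kept separate from the patch itself.

The `_start_mapping` block at the top of this section builds one `start` (and one `start-grpc`) command per entry in `CONFIG_MAPPING` via `start_command_factory`, so `openllm start <model>` resolves to a pre-built Click command. Below is a toy, self-contained sketch of that factory pattern; every name in it (`fake_config_mapping`, `make_start_command`) is invented for illustration and is not an OpenLLM API:

```python
# Toy illustration of generating one Click command per known model, in the
# spirit of `_start_mapping`. All names here are invented for the sketch.
from __future__ import annotations

import click

fake_config_mapping = {'opt': {}, 'llama': {}}  # stand-in for CONFIG_MAPPING

def make_start_command(model_name: str) -> click.Command:
  @click.command(name=model_name)
  @click.option('--model-id', default=None, help='Optional model id override.')
  def cmd(model_id: str | None) -> None:
    click.echo(f'would start an HTTP server for {model_name} (model_id={model_id})')
  return cmd

@click.group()
def cli() -> None:
  '''Toy CLI mirroring the per-model start-command factory pattern.'''

# Generate and register one subcommand per model name.
start_mapping = {name: make_start_command(name) for name in fake_config_mapping}
for command in start_mapping.values():
  cli.add_command(command)

if __name__ == '__main__':
  cli()
```

Invoked as `python sketch.py opt --model-id some/model`, the group dispatches to the generated per-model command, which is the same shape of dispatch the real mapping provides.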
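The reworked `query_command` replaces the old `fn = client.generate_stream if stream else client.generate` dispatch with two explicit branches, so the streaming iterator (`StreamResponse`) and the single `Response` object are handled separately. Here is a minimal sketch of the same split from the caller's side, assuming an OpenLLM HTTP server is already running at `http://localhost:3000`; the client calls (`HTTPClient`, `generate`, `generate_stream`, `.text`, `.responses`) mirror the ones used in the hunk, and anything beyond them is illustrative:

```python
# Minimal sketch: query a running OpenLLM server with and without streaming.
# Assumes a server is reachable at ENDPOINT (e.g. started via `openllm start opt`).
import openllm

ENDPOINT = 'http://localhost:3000'  # assumed server address
PROMPT = 'What is the difference between gather and scatter?'

def run(stream: bool = True) -> None:
  client = openllm.client.HTTPClient(ENDPOINT, timeout=30)
  if stream:
    # Streaming: iterate over partial responses and print text as it arrives.
    for chunk in client.generate_stream(PROMPT):
      print(chunk.text, end='', flush=True)
    print()
  else:
    # Non-streaming: a single response whose `responses` field holds the completions.
    res = client.generate(PROMPT)
    print(res.responses[0])

if __name__ == '__main__':
  run(stream=True)
```

Splitting the two paths trades a little duplication for straightforward control flow, and lets each branch keep its own output formatting and return type.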
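For context on the deleted `embeddings()` helpers (ChatGLM, Flan-T5 and Llama above): they all followed the same recipe of pooling the final hidden states and L2-normalising the result. Below is a standalone sketch of that masked mean-pool approach, closest to the removed Llama variant; the model id is a placeholder and nothing here is part of the patch:

```python
# Sketch of the masked mean-pool + L2-normalise embedding recipe that the
# removed `embeddings()` methods implemented. Placeholder model id; swap in
# any encoder-style checkpoint that defines a pad token.
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

MODEL_ID = 'sentence-transformers/all-MiniLM-L6-v2'  # placeholder

def embed(prompts: list[str]) -> tuple[list[list[float]], int]:
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
  model = AutoModel.from_pretrained(MODEL_ID)
  encoding = tokenizer(prompts, padding=True, truncation=True, return_tensors='pt')
  input_ids, attention_mask = encoding['input_ids'], encoding['attention_mask']
  with torch.inference_mode():
    # Last hidden state: (batch, seq_len, hidden)
    hidden = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
  # Zero out padding positions, then average over the sequence dimension.
  mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
  pooled = torch.sum(hidden * mask, dim=1) / torch.sum(mask, dim=1)
  embeddings = F.normalize(pooled, p=2, dim=1)
  return embeddings.tolist(), int(attention_mask.sum().item())
```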