[Deploy] Support Triton Inference Server Image
Raphael-Jin committed Dec 6, 2023
1 parent f5452aa commit ab61b6b
Showing 6 changed files with 121 additions and 31 deletions.
8 changes: 8 additions & 0 deletions devops/dockerfile/fedml-inference/triton/Dockerfile
@@ -0,0 +1,8 @@
ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver:22.01-py3
FROM ${BASE_IMAGE}

ENV DATA_CACHE_FOLDER=""
VOLUME [ DATA_CACHE_FOLDER ]

CMD /bin/bash ${BOOTSTRAP_DIR};\
tritonserver --model-repository=${DATA_CACHE_FOLDER};
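
A hedged sketch of how this image could be built and smoke-tested locally (the tag, build context, and mount paths are illustrative; in normal operation FedML builds and launches the container itself):

```sh
# Build the Triton serving image from the repository root; BASE_IMAGE can be overridden with --build-arg
docker build -t fedml/fedml-triton \
  -f devops/dockerfile/fedml-inference/triton/Dockerfile .

# Optional local smoke test: mount a Triton model repository and expose the default HTTP port
docker run --rm -p 8000:8000 \
  -e DATA_CACHE_FOLDER=/model_repository \
  -v /path/to/model_repository:/model_repository \
  fedml/fedml-triton
```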
75 changes: 75 additions & 0 deletions python/examples/deploy/triton/README.md
@@ -0,0 +1,75 @@
# Prerequisites

1. Inside `config.yaml`, change
`inference_image_name: "fedml/fedml-triton"`
to your own Docker image name.

2. If you have a model repository on your host machine that you want to mount into the container, change
`data_cache_dir: "/home/raphael/Triton/server/docs/examples/model_repository"`
to your own local directory (see the example layout after this list).

3. The HTTP port for the Triton Inference Server defaults to 8000. If you prefer another port, change
`port_inside_container: 8000`
to your inference port.
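
For reference, `data_cache_dir` is expected to point at a standard Triton model repository. A minimal sketch of such a layout (the model and file names are only illustrative):
```
model_repository/
└── simple/
    ├── config.pbtxt
    └── 1/
        └── model.graphdef
```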

# Create a new model card with a configuration file
Note that `$model_name` needs to match the name of the model registered in the Triton server.
e.g.
```
+----------------------+---------+--------+
| Model | Version | Status |
+----------------------+---------+--------+
| densenet_onnx | 1 | READY |
| inception_graphdef | 1 | READY |
| simple | 1 | READY |
| simple_dyna_sequence | 1 | READY |
| simple_identity | 1 | READY |
| simple_int8 | 1 | READY |
| simple_sequence | 1 | READY |
| simple_string | 1 | READY |
+----------------------+---------+--------+
```
If you want to serve `simple`, set `model_name=simple`.

```sh
cd FedML/python/fedml/serving/example/triton/
fedml model create --name $model_name --config_file config.yaml
```

## On-premise Deploy
Register an account on the FedML website: https://nexus.fedml.ai

You will get a user ID and an API key, which can be found on the profile page.

- Devices Login
```sh
fedml login $Your_UserId_or_ApiKey
```
You will see your FedML Edge Master and Worker ID in the terminal,
for example:
```
Congratulations, your device is connected to the FedML MLOps platform successfully!
Your FedML Edge ID is 32314, unique device ID is 0xxxxxxxx@MacOS.Edge.Device,
master deploy ID is 31240, worker deploy ID is 31239
```
Here the master ID is 31240 and the worker ID is 31239.


- Push model card
```sh
fedml model push --name $model_name
```

- OPT1: Deploy - CLI
```sh
fedml model deploy --name $model_name --master_ids $master_id --worker_ids $worker_id
```
- Result

See the deployment result at https://nexus.fedml.ai

- OPT2: Deploy - UI

Follow the instructions on https://nexus.fedml.ai
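
As a sanity check on the worker host itself, the Triton server inside the container can be queried directly over its standard v2 HTTP API. A hedged sketch, assuming the default port 8000 and the stock `simple` example model (two INT32 inputs of shape [1, 16]):

```sh
# Server readiness
curl -v http://localhost:8000/v2/health/ready

# Readiness of one model
curl -v http://localhost:8000/v2/models/simple/ready

# Minimal inference request against the `simple` example model
curl -X POST http://localhost:8000/v2/models/simple/infer \
  -H 'Content-Type: application/json' \
  -d '{"inputs": [
        {"name": "INPUT0", "shape": [1, 16], "datatype": "INT32", "data": [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]},
        {"name": "INPUT1", "shape": [1, 16], "datatype": "INT32", "data": [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]}
      ]}'
```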
7 changes: 7 additions & 0 deletions python/examples/deploy/triton/config.yaml
@@ -0,0 +1,7 @@
workspace: "."

# Related to the triton server
inference_image_name: "fedml/fedml-triton" # Replace with your own image name
data_cache_dir: "/home/raphael/Triton/server/docs/examples/model_repository"
port_inside_container: 8000
use_triton: true
@@ -343,8 +343,7 @@ def build_model(self, model_name):
config = yaml.safe_load(f)
main_entry_file = config.get("entry_point", "")
if main_entry_file == "":
print("The entry_point is missing in the model config file.")
return ""
print("[Warning] The entry_point is missing in the model config file.")
except:
print("You model repository is missing file {}, you should add it.".format(
ClientConstants.MODEL_REQUIRED_MODEL_BIN_FILE))
@@ -404,8 +403,7 @@ def push_model(self, model_name, user_id, user_api_key, model_storage_url=None,
config = yaml.safe_load(f)
main_entry_file = config.get("entry_point", "")
if main_entry_file == "":
print("The entry_point is missing in the model config file.")
return "", ""
print("[Warning] The entry_point is missing in the model config file.")
except:
print("You model repository is missing file {}, you should add it.".format(
ClientConstants.MODEL_REQUIRED_MODEL_BIN_FILE))
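In both `build_model` and `push_model` the change is the same: a missing `entry_point` used to abort with an early return, whereas now it only prints a warning, which allows model cards without a Python entry point (such as the Triton image added by this commit) to be built and pushed. A minimal sketch of the new behavior, with the surrounding function simplified for illustration:

```python
import yaml

def warn_if_entry_point_missing(model_config_path: str) -> None:
    """Warn, but do not fail, when the model config has no entry_point."""
    with open(model_config_path, "r") as f:
        config = yaml.safe_load(f)
    main_entry_file = config.get("entry_point", "")
    if main_entry_file == "":
        # Previously this aborted the build/push with an early return.
        print("[Warning] The entry_point is missing in the model config file.")
```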
@@ -271,10 +271,17 @@ def start_deployment(end_point_id, end_point_name, model_id, model_version,
docker_registry_user_password = config.get("docker_registry_user_password", "")
docker_registry = config.get("docker_registry", "")

port_inside_container = int(config.get("port_inside_container", 2345))
use_triton = config.get("use_triton", False)
if use_triton:
inference_type = "triton"
else:
inference_type = "default"

if src_code_dir == "":
raise Exception("Please indicate source_code_dir in the fedml_model_config.yaml")
if relative_entry == "":
raise Exception("Please indicate main_entry in the fedml_model_config.yaml")
logging.warning("You missed main_entry in the fedml_model_config.yaml")

if inference_engine == ClientConstants.INFERENCE_ENGINE_TYPE_INT_TRITON:
# configuration passed by user in the Cli
@@ -399,13 +406,13 @@ def start_deployment(end_point_id, end_point_name, model_id, model_version,
image=inference_image_name,
name=default_server_container_name,
volumes=volumns,
ports=[2345], # port open inside the container
ports=[port_inside_container], # port open inside the container
# entrypoint=["python3", relative_entry],
environment=environment,
host_config=client.api.create_host_config(
binds=binds,
port_bindings={
2345: usr_indicated_worker_port # Could be either None or a port number
port_inside_container: usr_indicated_worker_port # Could be either None or a port number
},
device_requests=device_requests,
# mem_limit = "8g", # Could also be configured in the docker desktop setting
@@ -425,7 +432,7 @@ def start_deployment(end_point_id, end_point_name, model_id, model_version,
break
else:
# Find the random port
port_info = client.api.port(new_container.get("Id"), 2345)
port_info = client.api.port(new_container.get("Id"), port_inside_container)
inference_http_port = port_info[0]["HostPort"]
logging.info("inference_http_port: {}".format(inference_http_port))
break
@@ -437,15 +444,15 @@ def start_deployment(end_point_id, end_point_name, model_id, model_version,
# Logging the info from the container
log_deployment_result(end_point_id, model_id, default_server_container_name,
ClientConstants.CMD_TYPE_RUN_DEFAULT_SERVER,
running_model_name, inference_engine, inference_http_port, inference_type="default",
inference_model_name, inference_engine, inference_http_port, inference_type,
retry_interval=10, deploy_attempt_threshold=usr_indicated_retry_cnt,
request_input_example=request_input_example, infer_host=infer_host,
enable_custom_image=enable_custom_image)

# Check if the inference server is ready
inference_output_url, running_model_version, ret_model_metadata, ret_model_config = \
get_model_info(running_model_name, inference_engine, inference_http_port,
infer_host, inference_type="default", request_input_example=request_input_example,
get_model_info(inference_model_name, inference_engine, inference_http_port,
infer_host, False, inference_type, request_input_example=request_input_example,
enable_custom_image=enable_custom_image)

if inference_output_url == "":
@@ -694,7 +701,8 @@ def log_deployment_result(end_point_id, model_id, cmd_container_name, cmd_type,

def is_client_inference_container_ready(infer_url_host, inference_http_port, inference_model_name, local_infer_url,
inference_type="default", model_version="", request_input_example=None):
logging.info(f"Inference type: {inference_type}, infer_url_host {infer_url_host}")
logging.info(f"Inference type: {inference_type}, infer_url_host {infer_url_host}, \
inference_http_port: {inference_http_port}, local_infer_url {local_infer_url}")

if inference_type == "default":
default_client_container_ready_url = "http://{}:{}/ready".format("0.0.0.0", inference_http_port)
@@ -716,20 +724,16 @@ def is_client_inference_container_ready(infer_url_host, inference_http_port, inf
model_metadata["type"] = "default"
return "http://{}:{}/predict".format(infer_url_host, inference_http_port), None, model_metadata, None
else:
triton_client = http_client.InferenceServerClient(url=local_infer_url, verbose=False)
wait_count = 0
while True:
if not triton_client.is_model_ready(
model_name=inference_model_name, model_version=model_version
):
logging.info(f"model {inference_model_name} not yet ready")
time.sleep(1)
wait_count += 1
if wait_count >= 15:
return "", model_version, {}, {}
else:
break

triton_server_url = "{}:{}".format(infer_url_host, inference_http_port)
if model_version == "" or model_version is None:
model_version = ClientConstants.INFERENCE_MODEL_VERSION
logging.info(f"triton_server_url: {triton_server_url} model_version: {model_version} model_name: {inference_model_name}")
triton_client = http_client.InferenceServerClient(url=triton_server_url, verbose=False)
if not triton_client.is_model_ready(
model_name=inference_model_name, model_version=model_version
):
return "", model_version, {}, {}
logging.info(f"Model {inference_model_name} is ready, start to get model metadata...")
model_metadata = triton_client.get_model_metadata(model_name=inference_model_name, model_version=model_version)
model_config = triton_client.get_model_config(model_name=inference_model_name, model_version=model_version)
version_list = model_metadata.get("versions", None)
@@ -749,8 +753,6 @@ def is_client_inference_container_ready(infer_url_host, inference_http_port, inf

def get_model_info(model_name, inference_engine, inference_http_port, infer_host="127.0.0.1", is_hg_model=False,
inference_type="default", request_input_example=None, enable_custom_image=False):
if enable_custom_image:
return f"http://{infer_host}:{inference_http_port}/api/v1/predict", "", {"inputs": {}}, {}

if model_name is None:
return "", "", {}, {}
@@ -773,15 +775,15 @@ def get_model_info(model_name, inference_engine, inference_http_port, infer_host


def run_http_inference_with_curl_request(inference_url, inference_input_list, inference_output_list,
inference_type="default"):
inference_type="default", engine_type="default"):
model_inference_result = {}
if inference_type == "default":
model_api_headers = {'Content-Type': 'application/json', 'Connection': 'close',
'Accept': 'application/json'}
else:
model_api_headers = {'Content-Type': 'application/json', 'Connection': 'close',
'Accept': inference_type}
if inference_type == "default" or inference_type == "image/png" or inference_type == "application/json":
if engine_type == "default":
model_inference_json = inference_input_list
else: # triton
model_inference_json = {
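The rewritten readiness check in this file drops the sleep-and-retry loop and instead asks the Triton server directly through its HTTP client, reusing the same calls to fetch metadata and config. A minimal, self-contained sketch of that flow using `tritonclient` (host, port, and model name are placeholders):

```python
import tritonclient.http as http_client

def query_triton_model(infer_url_host: str, inference_http_port: int,
                       inference_model_name: str, model_version: str = "1"):
    """Return (metadata, config) if the model is ready on the Triton server, else None."""
    triton_server_url = "{}:{}".format(infer_url_host, inference_http_port)
    client = http_client.InferenceServerClient(url=triton_server_url, verbose=False)

    if not client.is_model_ready(model_name=inference_model_name,
                                 model_version=model_version):
        return None

    # Same calls the deployment code uses to assemble the endpoint metadata
    model_metadata = client.get_model_metadata(model_name=inference_model_name,
                                               model_version=model_version)
    model_config = client.get_model_config(model_name=inference_model_name,
                                           model_version=model_version)
    return model_metadata, model_config
```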
@@ -629,7 +629,7 @@ def callback_deployment_result_message(self, topic=None, payload=None):
payload_json["input_json"] = {"end_point_name": end_point_name,
"model_name": model_name,
"token": str(token),
"inputs": ret_inputs,
"inputs": {"inputs": ret_inputs}, # Nested inputs
"outputs": model_metadata["outputs"]}
payload_json["output_json"] = model_metadata["outputs"]
FedMLModelCache.get_instance(self.redis_addr, self.redis_port). \
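The last hunk wraps `ret_inputs` in an additional `{"inputs": ...}` layer when recording the deployment result. A plausible reading, not confirmed beyond the diff itself, is that the stored `input_json` should mirror a Triton-style request body whose top-level `inputs` key holds the input descriptors. Illustrative sketch of the resulting structure (all concrete values are placeholders):

```python
# ret_inputs: input descriptors taken from the model metadata (hypothetical example)
ret_inputs = [{"name": "INPUT0", "datatype": "INT32", "shape": [1, 16]}]

input_json = {
    "end_point_name": "my_endpoint",      # placeholder
    "model_name": "simple",               # placeholder
    "token": "<token>",                   # placeholder
    "inputs": {"inputs": ret_inputs},     # nested inputs, as in this commit
    "outputs": [{"name": "OUTPUT0"}],     # placeholder
}
```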
