-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
48 lines (38 loc) · 1.4 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import subprocess
import time
import httpx
from cog import BasePredictor, ConcatenateIterator
#from sse import receive_sse
#from utils import (
# maybe_download_tarball_with_pget,
# StreamingTokenStopSequenceHandler,
#)
import pytriton.utils.distribution
# Path to the "tritonserver" directory under pytriton's root module path.
# Predictor.setup() resolves the server binary ("bin/tritonserver") and the
# backend directory ("backends") relative to this location.
TRITONSERVER_DIST_DIR = pytriton.utils.distribution.get_root_module_path() / "tritonserver"
class Predictor(BasePredictor):
    """Cog predictor that launches a local Triton server during setup."""

    def setup(self) -> None:
        """Start Triton as a child process and block until it reports ready.

        The ``tritonserver`` binary and its backends are taken from
        ``TRITONSERVER_DIST_DIR``; models are served from
        ``/src/triton_model_repo``. Readiness is polled over HTTP on
        ``localhost:8000``, pausing one second between attempts. Once the
        server is ready, an async HTTP client is stored on ``self.client``.

        Raises:
            RuntimeError: If the Triton process exits before becoming ready.
        """
        # Keep the process handle so the readiness loop can notice an early
        # exit instead of polling a dead server forever.
        self._triton_process = subprocess.Popen(
            [
                str(TRITONSERVER_DIST_DIR / "bin" / "tritonserver"),
                "--backend-dir", str(TRITONSERVER_DIST_DIR / "backends"),
                "--model-repository", "/src/triton_model_repo",
            ]
        )

        # Health check Triton until it is ready
        while True:
            exit_code = self._triton_process.poll()
            if exit_code is not None:
                raise RuntimeError(
                    f"Triton exited with code {exit_code} before becoming ready."
                )
            try:
                response = httpx.get("http://localhost:8000/v2/health/ready")
            except httpx.RequestError:
                # Server is not reachable yet; fall through and retry.
                pass
            else:
                if response.status_code == 200:
                    print("Triton is ready.")
                    break
            time.sleep(1)

        self.client = httpx.AsyncClient(timeout=10)

    async def predict(self) -> str:
        """Return a fixed greeting string."""
        return "hello!"