support model revision and tokenizer revision
Signed-off-by: Lize Cai <lize.cai@sap.com>
lizzzcai committed Mar 29, 2024
1 parent 8c722b6 commit 58a9988
Showing 3 changed files with 34 additions and 12 deletions.
5 changes: 5 additions & 0 deletions python/huggingfaceserver/huggingfaceserver/__main__.py
@@ -31,6 +31,10 @@ def list_of_strings(arg):
help='A URI pointer to the model binary')
parser.add_argument('--model_id', required=False,
help='Huggingface model id')
parser.add_argument('--model_revision', required=False, default=None,
help='Huggingface model revision')
parser.add_argument('--tokenizer_revision', required=False, default=None,
help='Huggingface tokenizer revision')
parser.add_argument('--max_length', type=int, default=None,
help='max sequence length for the tokenizer')
parser.add_argument('--disable_lower_case', action='store_true',
@@ -57,6 +61,7 @@ def list_of_strings(arg):
engine_args = None
if _vllm and not args.disable_vllm:
args.model = args.model_dir or args.model_id
args.revision = args.model_revision
engine_args = AsyncEngineArgs.from_cli_args(args)
predictor_config = PredictorConfig(args.predictor_host, args.predictor_protocol,
args.predictor_use_ssl,
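With these flags in place, a specific model revision can be pinned at server launch. A minimal sketch of how the new arguments flow through the vLLM path (the model id and revision values below are placeholders, and the `AsyncEngineArgs` call is left commented out because it needs vLLM and a GPU): `--model_revision` is copied onto `args.revision`, the attribute name that `AsyncEngineArgs.from_cli_args` reads.

```python
# Sketch only: emulates how __main__.py wires the new flags into vLLM.
# The model id and revisions below are placeholder values.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model_id', required=False)
parser.add_argument('--model_revision', required=False, default=None)
parser.add_argument('--tokenizer_revision', required=False, default=None)

args = parser.parse_args(['--model_id', 'bert-base-uncased',
                          '--model_revision', 'main',
                          '--tokenizer_revision', 'main'])

args.model = args.model_id            # vLLM expects args.model
args.revision = args.model_revision   # ...and args.revision
# engine_args = AsyncEngineArgs.from_cli_args(args)  # requires vLLM + GPU
```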
29 changes: 17 additions & 12 deletions python/huggingfaceserver/huggingfaceserver/model.py
@@ -65,6 +65,8 @@ def __init__(self, model_name: str, kwargs,
self.model_dir = kwargs.get('model_dir', None)
if not self.model_id and not self.model_dir:
self.model_dir = "/mnt/models"
self.model_revision = kwargs.get('model_revision', None)
self.tokenizer_revision = kwargs.get('tokenizer_revision', None)
self.do_lower_case = not kwargs.get('disable_lower_case', False)
self.add_special_tokens = not kwargs.get('disable_special_tokens', False)
self.max_length = kwargs.get('max_length', None)
@@ -91,8 +93,7 @@ def infer_task_from_model_architecture(model_config: str):
raise ValueError(f"Task couldn't be inferred from {architecture}. Please manually set `task` option.")

@staticmethod
def infer_vllm_supported_from_model_architecture(model_config_path: str):
model_config = AutoConfig.from_pretrained(model_config_path)
def infer_vllm_supported_from_model_architecture(model_config: str):
architecture = model_config.architectures[0]
model_cls = ModelRegistry.load_model_cls(architecture)
if model_cls is None:
@@ -101,18 +102,22 @@ def infer_vllm_supported_from_model_architecture(model_config_path: str):

def load(self) -> bool:
model_id_or_path = self.model_id
revision = self.model_revision
tokenizer_revision = self.tokenizer_revision
if self.model_dir:
model_id_or_path = pathlib.Path(Storage.download(self.model_dir))
# TODO Read the mapping file, index to object name

model_config = AutoConfig.from_pretrained(model_id_or_path, revision=revision)

if self.use_vllm and self.device == torch.device("cuda"): # vllm needs gpu
if self.infer_vllm_supported_from_model_architecture(model_id_or_path):
if self.infer_vllm_supported_from_model_architecture(model_config):
logger.info("supported model by vLLM")
self.vllm_engine_args.tensor_parallel_size = torch.cuda.device_count()
self.vllm_engine = AsyncLLMEngine.from_engine_args(self.vllm_engine_args)
self.ready = True
return self.ready

model_config = AutoConfig.from_pretrained(model_id_or_path)

if not self.task:
self.task = self.infer_task_from_model_architecture(model_config)

@@ -126,7 +131,7 @@ def load(self) -> bool:
self.device_map = "auto"
# load huggingface tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
model_id_or_path, do_lower_case=self.do_lower_case, device_map=self.device_map)
model_id_or_path, revision=tokenizer_revision, do_lower_case=self.do_lower_case, device_map=self.device_map)
if not self.tokenizer.pad_token:
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
logger.info(f"successfully loaded tokenizer for task: {self.task}")
@@ -135,19 +140,19 @@ def load(self) -> bool:
if not self.predictor_host:
if self.task == MLTask.sequence_classification.value:
self.model = AutoModelForSequenceClassification.from_pretrained(
model_id_or_path, device_map=self.device_map)
model_id_or_path, revision=revision, device_map=self.device_map)
elif self.task == MLTask.question_answering.value:
self.model = AutoModelForQuestionAnswering.from_pretrained(
model_id_or_path, device_map=self.device_map)
model_id_or_path, revision=revision, device_map=self.device_map)
elif self.task == MLTask.token_classification.value:
self.model = AutoModelForTokenClassification.from_pretrained(
model_id_or_path, device_map=self.device_map)
model_id_or_path, revision=revision, device_map=self.device_map)
elif self.task == MLTask.fill_mask.value:
self.model = AutoModelForMaskedLM.from_pretrained(model_id_or_path, device_map=self.device_map)
self.model = AutoModelForMaskedLM.from_pretrained(model_id_or_path, revision=revision, device_map=self.device_map)
elif self.task == MLTask.text_generation.value:
self.model = AutoModelForCausalLM.from_pretrained(model_id_or_path, device_map=self.device_map)
self.model = AutoModelForCausalLM.from_pretrained(model_id_or_path, revision=revision, device_map=self.device_map)
elif self.task == MLTask.text2text_generation.value:
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id_or_path, device_map=self.device_map)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id_or_path, revision=revision, device_map=self.device_map)
else:
raise ValueError(f"Unsupported task {self.task}. Please check the supported `task` option.")
self.model.eval()
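The server-side change boils down to threading `revision=` through each `transformers` loader. The same pinning can be reproduced standalone; a minimal sketch, assuming access to the Hugging Face Hub (any branch, tag, or commit SHA works as a revision, and the tokenizer revision may be set independently of the model revision):

```python
# Standalone sketch of revision pinning with transformers (assumes hub access).
from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer

model_id = "bert-base-uncased"
revision = "main"            # branch, tag, or commit SHA
tokenizer_revision = "main"  # may be pinned independently of the model

config = AutoConfig.from_pretrained(model_id, revision=revision)
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=tokenizer_revision)
model = AutoModelForMaskedLM.from_pretrained(model_id, revision=revision)
```

Pinning both loaders guards against the model and tokenizer silently drifting apart when a hub repository's `main` branch moves.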
12 changes: 12 additions & 0 deletions python/huggingfaceserver/huggingfaceserver/test_model.py
@@ -39,6 +39,18 @@ def test_bert():
"The capital of [MASK] is paris."]}, headers={}))
assert response == {"predictions": ["paris", "france"]}

def test_model_revision():
model = HuggingfaceModel("bert-base-uncased",
{"model_id": "bert-base-uncased",
"model_revision":"main",
"tokenizer_revision":"main",
"disable_lower_case": False}
)
model.load()

response = asyncio.run(model({"instances": ["The capital of France is [MASK].",
"The capital of [MASK] is paris."]}, headers={}))
assert response == {"predictions": ["paris", "france"]}

def test_bert_predictor_host(httpx_mock: HTTPXMock):
httpx_mock.add_response(json={"outputs": [{"name": "OUTPUT__0", "shape": [1, 9, 758],
