
fix(yapf): align weird new lines break [generated] [skip ci] (#284)
fix(yapf): align weird new lines break

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
aarnphm authored Sep 1, 2023
1 parent 3e45530 commit b7af776
Showing 91 changed files with 812 additions and 1,679 deletions.
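The [generated] [skip ci] markers in the title suggest this reflow was produced by tooling rather than by hand. Most of the changes below follow the same pattern: call sites that yapf had previously broken across several short lines are joined back onto one line, or realigned under the opening parenthesis. As a rough illustration of how such a repo-wide reflow is reproduced, the sketch below runs yapf's Python API over one of the affected snippets. The style values are assumptions for illustration only; the repository's actual yapf configuration is not part of this diff.

# Sketch only: the style options below are assumed, not taken from this repository.
from yapf.yapflib import yapf_api

SNIPPET = '''
table.append([
    filepath.replace(os.path.join(dir, 'src'), ''), line_count,
    token_count / line_count if line_count != 0 else 0
])
'''

# A wide column_limit is what allows yapf to join these arguments onto a single line.
formatted, changed = yapf_api.FormatCode(SNIPPET, style_config='{based_on_style: google, column_limit: 192, indent_width: 2}')
print(formatted)  # expected: a one-line table.append(...) call, as in the diff below
print(changed)    # True when yapf rewrote the input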
cz.py (5 changes: 1 addition & 4 deletions)
@@ -19,10 +19,7 @@ def run_cz(dir: str, package: str):
with tokenize.open(filepath) as file_:
tokens = [t for t in tokenize.generate_tokens(file_.readline) if t.type in TOKEN_WHITELIST]
token_count, line_count = len(tokens), len(set([t.start[0] for t in tokens]))
-table.append([
-    filepath.replace(os.path.join(dir, 'src'), ''), line_count,
-    token_count / line_count if line_count != 0 else 0
-])
+table.append([filepath.replace(os.path.join(dir, 'src'), ''), line_count, token_count / line_count if line_count != 0 else 0])
print(tabulate([headers, *sorted(table, key=lambda x: -x[1])], headers='firstrow', floatfmt='.1f') + '\n')
for dir_name, group in itertools.groupby(sorted([(x[0].rsplit('/', 1)[0], x[1]) for x in table]), key=lambda x: x[0]):
print(f'{dir_name:35s} : {sum([x[1] for x in group]):6d}')
examples/langchain-chains-demo/service.py (16 changes: 5 additions & 11 deletions)
@@ -44,17 +44,11 @@ def gen_llm(model_name: str, model_id: str | None = None) -> OpenLLM:
def download(_: bentoml.Context):
llm.runner.download_model()

-SAMPLE_INPUT = Query(
-    industry="SAAS",
-    product_name="BentoML",
-    keywords=["open source", "developer tool", "AI application platform", "serverless", "cost-efficient"],
-    llm_config=llm.runner.config.model_dump(),
-)
+SAMPLE_INPUT = Query(industry="SAAS",
+                     product_name="BentoML",
+                     keywords=["open source", "developer tool", "AI application platform", "serverless", "cost-efficient"],
+                     llm_config=llm.runner.config.model_dump())

@svc.api(input=JSON.from_sample(sample=SAMPLE_INPUT), output=Text())
def generate(query: Query):
-return chain.run({
-    "industry": query.industry,
-    "product_name": query.product_name,
-    "keywords": ", ".join(query.keywords)
-})
+return chain.run({"industry": query.industry, "product_name": query.product_name, "keywords": ", ".join(query.keywords)})
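The reformatting does not change the service's behaviour; the generate API accepts the same Query payload as before. Purely as an illustration (the localhost address, port 3000, and the /generate route are BentoML defaults assumed here, not shown in this diff), a request against the running demo might look like:

# Illustrative request only: address, port and route are assumed BentoML defaults.
import requests

payload = {
    'industry': 'SAAS',
    'product_name': 'BentoML',
    'keywords': ['open source', 'developer tool'],
    'llm_config': {},  # placeholder; the demo populates this from llm.runner.config
}
response = requests.post('http://localhost:3000/generate', json=payload)
print(response.text)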
openllm-client/src/openllm_client/_base.py (80 changes: 24 additions & 56 deletions)
@@ -65,10 +65,7 @@ def query(self, prompt: str, *, return_response: t.Literal['attrs'], **attrs: t.
...

@abc.abstractmethod
-def query(self,
-          prompt: str,
-          return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed',
-          **attrs: t.Any) -> t.Any:
+def query(self, prompt: str, return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed', **attrs: t.Any) -> t.Any:
raise NotImplementedError

# NOTE: Scikit interface
@@ -84,8 +81,7 @@ def predict(self, prompt: str, *, return_response: t.Literal['raw'], **attrs: t.

@overload
@abc.abstractmethod
-def predict(self, prompt: str, *, return_response: t.Literal['attrs'],
-            **attrs: t.Any) -> openllm_core.GenerationOutput:
+def predict(self, prompt: str, *, return_response: t.Literal['attrs'], **attrs: t.Any) -> openllm_core.GenerationOutput:
...

@abc.abstractmethod
@@ -95,14 +91,12 @@ def predict(self, prompt: str, **attrs: t.Any) -> t.Any:
@functools.cached_property
def _hf_agent(self) -> transformers.HfAgent:
if not is_transformers_available():
-raise RuntimeError(
-    "transformers is required to use HF agent. Install with 'pip install \"openllm-client[agents]\"'.")
+raise RuntimeError("transformers is required to use HF agent. Install with 'pip install \"openllm-client[agents]\"'.")
if not self.supports_hf_agent:
raise RuntimeError(f'{self.model_name} ({self.backend}) does not support running HF agent.')
if not is_transformers_supports_agent():
raise RuntimeError(
-    "Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'"
-)
+    "Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'")
import transformers
return transformers.HfAgent(urljoin(self._address, '/hf/agent'))

@@ -183,13 +177,7 @@ def inner(self) -> BentoClient:
return BentoClient.from_url(self._address)

# Agent integration
-def ask_agent(self,
-              task: str,
-              *,
-              return_code: bool = False,
-              remote: bool = False,
-              agent_type: LiteralString = 'hf',
-              **attrs: t.Any) -> t.Any:
+def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = 'hf', **attrs: t.Any) -> t.Any:
if agent_type == 'hf': return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")

@@ -223,20 +211,13 @@ def inner(self) -> AsyncBentoClient:
return ensure_exec_coro(AsyncBentoClient.from_url(self._address))

# Agent integration
-async def ask_agent(self,
-                    task: str,
-                    *,
-                    return_code: bool = False,
-                    remote: bool = False,
-                    agent_type: LiteralString = 'hf',
-                    **attrs: t.Any) -> t.Any:
+async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = 'hf', **attrs: t.Any) -> t.Any:
if agent_type == 'hf': return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")

async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
if not is_transformers_supports_agent():
-raise RuntimeError(
-    'This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0')
+raise RuntimeError('This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0')
if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
from transformers.tools.agents import clean_code_for_run
from transformers.tools.agents import get_tool_creation_code
@@ -272,31 +253,23 @@ async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
self._hf_agent.log(f'\n\n==Code generated by the agent==\n{code}')
if not return_code:
self._hf_agent.log('\n\n==Result==')
-self._hf_agent.cached_tools = resolve_tools(code,
-                                            self._hf_agent.toolbox,
-                                            remote=remote,
-                                            cached_tools=self._hf_agent.cached_tools)
+self._hf_agent.cached_tools = resolve_tools(code, self._hf_agent.toolbox, remote=remote, cached_tools=self._hf_agent.cached_tools)
return evaluate(code, self._hf_agent.cached_tools, state=kwargs.copy())
else:
tool_code = get_tool_creation_code(code, self._hf_agent.toolbox, remote=remote)
return f'{tool_code}\n{code}'

class BaseClient(_Client):

def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
raise NotImplementedError

def embed(self, prompt: t.Sequence[str] | str) -> openllm_core.EmbeddingsOutput:
-return openllm_core.EmbeddingsOutput(
-    **self.call('embeddings', list([prompt] if isinstance(prompt, str) else prompt)))
+return openllm_core.EmbeddingsOutput(**self.call('embeddings', list([prompt] if isinstance(prompt, str) else prompt)))

def predict(self, prompt: str, **attrs: t.Any) -> openllm_core.GenerationOutput | DictStrAny | str:
return self.query(prompt, **attrs)

-def query(self,
-          prompt: str,
-          return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed',
-          **attrs: t.Any) -> t.Any:
+def query(self, prompt: str, return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed', **attrs: t.Any) -> t.Any:
return_raw_response = attrs.pop('return_raw_response', None)
if return_raw_response is not None:
logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.")
@@ -306,32 +279,27 @@ def query(self,
logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.")
if return_attrs is True: return_response = 'attrs'
use_default_prompt_template = attrs.pop('use_default_prompt_template', False)
-prompt, generate_kwargs, postprocess_kwargs = self.config.sanitize_parameters(
-    prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
-r = openllm_core.GenerationOutput(**self.call(
-    'generate',
-    openllm_core.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(
-        **generate_kwargs)).model_dump()))
+prompt, generate_kwargs, postprocess_kwargs = self.config.sanitize_parameters(prompt,
+                                                                               use_default_prompt_template=use_default_prompt_template,
+                                                                               **attrs)
+r = openllm_core.GenerationOutput(
+    **self.call('generate',
+                openllm_core.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)).model_dump()))
if return_response == 'attrs': return r
elif return_response == 'raw': return bentoml_cattr.unstructure(r)
else: return self.config.postprocess_generate(prompt, r.responses, **postprocess_kwargs)

class BaseAsyncClient(_AsyncClient):

async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
raise NotImplementedError

async def embed(self, prompt: t.Sequence[str] | str) -> openllm_core.EmbeddingsOutput:
-return openllm_core.EmbeddingsOutput(
-    **(await self.call('embeddings', list([prompt] if isinstance(prompt, str) else prompt))))
+return openllm_core.EmbeddingsOutput(**(await self.call('embeddings', list([prompt] if isinstance(prompt, str) else prompt))))

async def predict(self, prompt: str, **attrs: t.Any) -> t.Any:
return await self.query(prompt, **attrs)

-async def query(self,
-                prompt: str,
-                return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed',
-                **attrs: t.Any) -> t.Any:
+async def query(self, prompt: str, return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed', **attrs: t.Any) -> t.Any:
return_raw_response = attrs.pop('return_raw_response', None)
if return_raw_response is not None:
logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.")
@@ -341,12 +309,12 @@ async def query(self,
logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.")
if return_attrs is True: return_response = 'attrs'
use_default_prompt_template = attrs.pop('use_default_prompt_template', False)
-prompt, generate_kwargs, postprocess_kwargs = self.config.sanitize_parameters(
-    prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
-r = openllm_core.GenerationOutput(**(await self.call(
-    'generate',
-    openllm_core.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(
-        **generate_kwargs)).model_dump())))
+prompt, generate_kwargs, postprocess_kwargs = self.config.sanitize_parameters(prompt,
+                                                                               use_default_prompt_template=use_default_prompt_template,
+                                                                               **attrs)
+r = openllm_core.GenerationOutput(
+    **(await self.call('generate',
+                       openllm_core.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)).model_dump())))
if return_response == 'attrs': return r
elif return_response == 'raw': return bentoml_cattr.unstructure(r)
else: return self.config.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
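For context on the signatures being reflowed in this file: query() selects the response shape via return_response ('processed', 'raw', or 'attrs'), while the older return_raw_response and return_attrs flags only emit deprecation warnings. A hypothetical usage sketch follows; the concrete HTTPClient class and server address are assumptions and not part of this diff.

# Hypothetical sketch: the concrete client class and address are assumptions.
from openllm_client import HTTPClient  # assumed concrete subclass of BaseClient

client = HTTPClient('http://localhost:3000')

processed = client.query('What is BentoML?')                        # default: postprocessed string
raw = client.query('What is BentoML?', return_response='raw')       # unstructured dict
attrs = client.query('What is BentoML?', return_response='attrs')   # openllm_core.GenerationOutput

# Deprecated spelling: still accepted, but logs a warning and maps to return_response='raw'.
legacy = client.query('What is BentoML?', return_raw_response=True)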
