# custom_hf_llm.py
from typing import Any, List, Mapping, Optional
from baselib import aiutils as aiutils
from baselib import httputils as http
from baselib import baselog as log
from baselib import langchain_utils as lutils
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs.llm_result import LLMResult
import requests
"""
**************************************************
* HFCustomLLM
**************************************************
"""
class HFCustomLLM(LLM):
    # Pydantic fields required at construction time.
    n: int
    name: str

    @property
    def _llm_type(self) -> str:
        return "custom"
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        return self._talkToTheHand(prompt)
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"n": self.n, "name": self.name}
    def _talkToTheHand(self, prompt: str) -> str:
        # Query the Hugging Face endpoint and pull the generated text
        # out of the JSON response.
        params = self._getParameters1()
        response: requests.Response = self._queryModel(prompt, params)
        http.understandResponse(response)
        text = self._extractGeneratedText(response)
        return text
    def _queryModel(self, prompt, parameters):
        # POST the prompt to the HF Inference API endpoint with a
        # bearer-token header; the key and URL come from baselib.
        apiKey = aiutils.getHFAPIKey()
        hfEndPointUrl = aiutils.getSampleHFEndPoint()
        headers = {"Authorization": f"Bearer {apiKey}"}
        payload = {
            "inputs": prompt,
            "parameters": parameters
        }
        response = requests.post(hfEndPointUrl, headers=headers, json=payload)
        return response
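    # For text-generation models the HF Inference API typically replies
    # with a JSON list of objects, e.g. [{"generated_text": "..."}]; the
    # extraction below assumes that shape and will raise on error replies
    # or other task types.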
    def _extractGeneratedText(self, response: requests.Response) -> str:
        return response.json()[0]['generated_text']
    def _getParameters1(self):
        # Note: with do_sample False the endpoint decodes greedily, so
        # temperature and top_p have no effect; they only matter when
        # sampling is enabled.
        return {
            "max_new_tokens": 200,
            "temperature": 0.6,
            "top_p": 0.9,
            "do_sample": False,
            "return_full_text": False
        }

    def _getParameters(self):
        # Minimal alternative parameter set (currently unused).
        return {
            "max_length": 200
        }
    def _getTestPrompt(self):
        question = "What is the population of Jacksonville, Florida?"
        return question

    def _getTestPrompt2(self):
        # Context-stuffed variant of the test prompt.
        question = "What is the population of Jacksonville, Florida?"
        context = "As of the most recent census, Jacksonville, Florida has a population of 1 million."
        prompt = f"""Use the following context to answer the question at the end.
{context}
Question: {question}
"""
        return prompt
    def selfTest(self: LLM):
        #answer: str = self.invoke("All roses are red")
        answer: LLMResult = self.generate(["All roses are red"])
        lutils.examineTextFrom_HF_LLM_Reply(answer)
        text = lutils.getSingleText_From_HF_LLM_Reply(answer)
        log.ph("Answer from self test LLM", text)
"""
**************************************************
* EOF_Class: HFCustomLLM
**************************************************
"""
def testHFCustomLLM():
    llm = HFCustomLLM(n=10, name="Satya")
    llm.selfTest()

def localTest():
    log.ph1("Starting local test")
    testHFCustomLLM()
    log.ph1("End local test")

if __name__ == '__main__':
    localTest()
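"""
**************************************************
* Usage sketch: because HFCustomLLM implements the LLM
* interface, it composes with LangChain runnables. A minimal
* sketch, assuming langchain_core's PromptTemplate is
* available in this environment:
*
*   from langchain_core.prompts import PromptTemplate
*   template = PromptTemplate.from_template("Answer briefly: {question}")
*   chain = template | HFCustomLLM(n=10, name="Satya")
*   print(chain.invoke({"question": "What is the capital of Florida?"}))
**************************************************
"""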