forked from MehrabRahman/langchaindemo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e35355a
commit 20bf0cc
Showing
12 changed files
with
458 additions
and
134 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
What year is the state of the union that is mentioned here? | ||
What is the Chips Act? | ||
Who all are attending this state of the union? | ||
Who all are attending this state of the union? And can you put them in a list? | ||
what are the key topics covered in the state of the union? Can you put them in a mark down list? | ||
What are the key topics covered in the state of the union? |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#""" | ||
#************************************************* | ||
#* Approach 1: streamlit_ui.py | ||
#************************************************* | ||
#""" | ||
def initialize_state1(): | ||
writeIntro() | ||
|
||
"""Initialize the state variable.""" | ||
if 'totalResponseText' not in st.session_state: | ||
st.session_state.totalResponseText = [] | ||
|
||
def writeOutput1(): | ||
# Display the accumulated text | ||
st.write("Accumulated Text:") | ||
for line in st.session_state.totalResponseText: | ||
st.write(line) | ||
|
||
def process_input1(input_text): | ||
"""Append the input text to the totalResponseText state variable.""" | ||
st.session_state.totalResponseText.append(input_text) | ||
|
||
def addToLog1(text: str): | ||
st.session_state.totalResponseText.append(text) | ||
|
||
def clearLog1(): | ||
st.session_state.totalResponseText.clear() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
|
||
""" | ||
************************************************* | ||
* Base libs | ||
************************************************* | ||
""" | ||
from baselib import baselog as log | ||
from customllms.custom_fb_hf_llm import FB_HFCustomLLM | ||
from typing import Tuple | ||
""" | ||
************************************************* | ||
* LangChain and HF | ||
************************************************* | ||
""" | ||
from langchain.chains import LLMChain | ||
from langchain.prompts import PromptTemplate | ||
from vectorlib.database import DatabaseRepo | ||
from vectorlib.database import Database | ||
from langchain_core.vectorstores import VectorStore | ||
from langchain_core.vectorstores import VectorStoreRetriever | ||
from langchain_core.runnables import RunnablePassthrough | ||
from langchain_core.output_parsers import StrOutputParser | ||
from langchain_core.language_models.llms import LLM | ||
|
||
""" | ||
************************************************* | ||
* Local stuff: siblings | ||
************************************************* | ||
""" | ||
from ui.wizard import Wizard | ||
|
||
""" | ||
************************************************* | ||
* Class Wizard | ||
************************************************* | ||
log.turnOffDebug() | ||
""" | ||
class LangChainHFWizard(Wizard): | ||
llm: LLM | ||
prompt: PromptTemplate | ||
retriever: VectorStoreRetriever | ||
|
||
def __init__(self): | ||
#get the prompt | ||
self.prompt = self._getTemplate() | ||
|
||
#get the llm | ||
self.llm = DatabaseRepo.get_fbhf_LLM() | ||
|
||
#vector db stuff | ||
db: VectorStore = DatabaseRepo.getSOFUDatabase().get() | ||
self.retriever = db.as_retriever(search_kwargs={"k": 1}) | ||
self.chain = _getChain(self.llm, self.prompt, self.retriever) | ||
|
||
def _getTemplate(self) -> PromptTemplate: | ||
template = """Instructions: Use only the following context to answer the question. | ||
Context: {context} | ||
Question: {question} | ||
""" | ||
prompt = PromptTemplate(template=template, input_variables=["context", "question"]) | ||
return prompt | ||
|
||
# | ||
# Interface | ||
def question(self, question: str) -> Tuple[str, str]: | ||
answer = self.chain.invoke(question) | ||
return (question, answer) | ||
|
||
""" | ||
************************************************* | ||
* Some utility funcs | ||
************************************************* | ||
""" | ||
def _getChain(llm: LLM, prompt: PromptTemplate, retriever: VectorStoreRetriever): | ||
chain = ( | ||
{"context": retriever, "question": RunnablePassthrough()} | ||
| prompt | ||
| llm | ||
| StrOutputParser() | ||
) | ||
return chain | ||
|
||
|
||
def test(): | ||
wizard = LangChainHFWizard() | ||
q, a = wizard.question("What is the Chips Act?") | ||
log.ph("Final answer", a) | ||
|
||
def localTest(): | ||
#log.turnOffDebug() | ||
log.ph1("Starting local test") | ||
test() | ||
log.ph1("End local test") | ||
|
||
if __name__ == '__main__': | ||
localTest() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
|
||
from ui.wizard import ( | ||
Wizard, | ||
FakeWizard | ||
) | ||
from ui.LangChainHFWizard import LangChainHFWizard | ||
|
||
class WizardServices(): | ||
class_wizard: Wizard = FakeWizard() | ||
class_real_wizard: Wizard = LangChainHFWizard() | ||
|
||
@staticmethod | ||
def getWizard() -> Wizard: | ||
return WizardServices.class_real_wizard | ||
|
||
@staticmethod | ||
def getFakeWizard() -> Wizard: | ||
return WizardServices.class_wizard | ||
|
||
@staticmethod | ||
def getRealWizard() -> Wizard: | ||
return WizardServices.class_real_wizard |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
|
||
import streamlit as st | ||
|
||
# Using HTML with inline CSS to change text color | ||
red_text_html = '<p style="color: red;">This text is red.</p>' | ||
|
||
def sendToScreen(html: str): | ||
st.markdown(html, unsafe_allow_html=True) | ||
|
||
def writeNote(text: str): | ||
note = f"<p style='color:red;'>{text}</p>" | ||
sendToScreen(note) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
|
||
from baselib import baselog as log | ||
from ui.ui_state import ApplicationState | ||
|
||
from ui.questions import ( | ||
Question, | ||
QuestionRepo | ||
) | ||
|
||
def _getMenuItem(question: Question): | ||
# - [item1](/?arg1='item1') | ||
name = question.brief_description | ||
id = question.id | ||
mi = f"1. [{name}](/?question={id})" | ||
return mi | ||
|
||
def getQuestionMenu(questionRepo: QuestionRepo): | ||
menuStr = f"# Questions" | ||
qr = questionRepo | ||
qlist = qr.getQuestionList() | ||
for item in qlist: | ||
question: Question = item | ||
qurl = _getMenuItem(question) | ||
menuStr = f"{menuStr}\n{qurl}" | ||
return menuStr | ||
|
||
|
||
def test(): | ||
q = Question.getASampleQuestion() | ||
log.info(_getMenuItem(q)) | ||
|
||
def testMenu(): | ||
qrepo = QuestionRepo.getSampleRepo() | ||
menu = getQuestionMenu(qrepo) | ||
log.ph("Menu",menu) | ||
|
||
def localTest(): | ||
log.ph1("Starting local test") | ||
testMenu() | ||
log.ph1("End local test") | ||
|
||
if __name__ == '__main__': | ||
localTest() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.