Skip to content

Commit

Permalink
Final demo version
Browse files Browse the repository at this point in the history
  • Loading branch information
SatyaKomatineni committed Feb 15, 2024
1 parent e35355a commit 20bf0cc
Show file tree
Hide file tree
Showing 12 changed files with 458 additions and 134 deletions.
12 changes: 11 additions & 1 deletion planning/planning.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,14 @@ methods:

Can you also write test function that tests the repo with a few questions?
is there a better way to do this?
how could I have specified with fewer details in the future?
how could I have specified with fewer details in the future?

"""
*************************************************
* wizard ui
*************************************************
"""
Create an abstract class
name: Wizard
abstract methods:
question() returns a pair (question, answer)
6 changes: 6 additions & 0 deletions planning/prepped_questions.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
What year is the state of the union that is mentioned here?
What is the Chips Act?
Who all are attending this state of the union?
Who all are attending this state of the union? And can you put them in a list?
what are the key topics covered in the state of the union? Can you put them in a mark down list?
What are the key topics covered in the state of the union?
27 changes: 27 additions & 0 deletions planning/trash.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#"""
#*************************************************
#* Approach 1: streamlit_ui.py
#*************************************************
#"""
def initialize_state1():
writeIntro()

"""Initialize the state variable."""
if 'totalResponseText' not in st.session_state:
st.session_state.totalResponseText = []

def writeOutput1():
# Display the accumulated text
st.write("Accumulated Text:")
for line in st.session_state.totalResponseText:
st.write(line)

def process_input1(input_text):
"""Append the input text to the totalResponseText state variable."""
st.session_state.totalResponseText.append(input_text)

def addToLog1(text: str):
st.session_state.totalResponseText.append(text)

def clearLog1():
st.session_state.totalResponseText.clear()
96 changes: 96 additions & 0 deletions ui/LangChainHFWizard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@

"""
*************************************************
* Base libs
*************************************************
"""
from baselib import baselog as log
from customllms.custom_fb_hf_llm import FB_HFCustomLLM
from typing import Tuple
"""
*************************************************
* LangChain and HF
*************************************************
"""
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from vectorlib.database import DatabaseRepo
from vectorlib.database import Database
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.language_models.llms import LLM

"""
*************************************************
* Local stuff: siblings
*************************************************
"""
from ui.wizard import Wizard

"""
*************************************************
* Class Wizard
*************************************************
log.turnOffDebug()
"""
class LangChainHFWizard(Wizard):
llm: LLM
prompt: PromptTemplate
retriever: VectorStoreRetriever

def __init__(self):
#get the prompt
self.prompt = self._getTemplate()

#get the llm
self.llm = DatabaseRepo.get_fbhf_LLM()

#vector db stuff
db: VectorStore = DatabaseRepo.getSOFUDatabase().get()
self.retriever = db.as_retriever(search_kwargs={"k": 1})
self.chain = _getChain(self.llm, self.prompt, self.retriever)

def _getTemplate(self) -> PromptTemplate:
template = """Instructions: Use only the following context to answer the question.
Context: {context}
Question: {question}
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
return prompt

#
# Interface
def question(self, question: str) -> Tuple[str, str]:
answer = self.chain.invoke(question)
return (question, answer)

"""
*************************************************
* Some utility funcs
*************************************************
"""
def _getChain(llm: LLM, prompt: PromptTemplate, retriever: VectorStoreRetriever):
chain = (
{"context": retriever, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
return chain


def test():
wizard = LangChainHFWizard()
q, a = wizard.question("What is the Chips Act?")
log.ph("Final answer", a)

def localTest():
#log.turnOffDebug()
log.ph1("Starting local test")
test()
log.ph1("End local test")

if __name__ == '__main__':
localTest()
22 changes: 22 additions & 0 deletions ui/WizardServices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

from ui.wizard import (
Wizard,
FakeWizard
)
from ui.LangChainHFWizard import LangChainHFWizard

class WizardServices():
class_wizard: Wizard = FakeWizard()
class_real_wizard: Wizard = LangChainHFWizard()

@staticmethod
def getWizard() -> Wizard:
return WizardServices.class_real_wizard

@staticmethod
def getFakeWizard() -> Wizard:
return WizardServices.class_wizard

@staticmethod
def getRealWizard() -> Wizard:
return WizardServices.class_real_wizard
12 changes: 12 additions & 0 deletions ui/html_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

import streamlit as st

# Using HTML with inline CSS to change text color
red_text_html = '<p style="color: red;">This text is red.</p>'

def sendToScreen(html: str):
st.markdown(html, unsafe_allow_html=True)

def writeNote(text: str):
note = f"<p style='color:red;'>{text}</p>"
sendToScreen(note)
43 changes: 43 additions & 0 deletions ui/question_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

from baselib import baselog as log
from ui.ui_state import ApplicationState

from ui.questions import (
Question,
QuestionRepo
)

def _getMenuItem(question: Question):
# - [item1](/?arg1='item1')
name = question.brief_description
id = question.id
mi = f"1. [{name}](/?question={id})"
return mi

def getQuestionMenu(questionRepo: QuestionRepo):
menuStr = f"# Questions"
qr = questionRepo
qlist = qr.getQuestionList()
for item in qlist:
question: Question = item
qurl = _getMenuItem(question)
menuStr = f"{menuStr}\n{qurl}"
return menuStr


def test():
q = Question.getASampleQuestion()
log.info(_getMenuItem(q))

def testMenu():
qrepo = QuestionRepo.getSampleRepo()
menu = getQuestionMenu(qrepo)
log.ph("Menu",menu)

def localTest():
log.ph1("Starting local test")
testMenu()
log.ph1("End local test")

if __name__ == '__main__':
localTest()
43 changes: 42 additions & 1 deletion ui/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ def __init__(self, id, brief_description, full_description, answer=None):
self.full_description = full_description
self.answer = answer

@staticmethod
def getASampleQuestion():
return Question(3, "Brief description 3", "Full description 3", "answer")

def addAnswer(self, answer):
"""Sets the answer for the question."""
self.answer = answer
Expand All @@ -18,11 +22,31 @@ def addAnswer(self, answer):
class QuestionRepo:
def __init__(self):
self.questions = OrderedDict()
self.curid = 1

def addQuestion(self, question):
def addQuestion(self, question: Question):
"""Adds a Question object to the questions dictionary."""
self.questions[question.id] = question

def addStringAsQuestion(self, question: str):
if not question:
log.warn("Empty question. ignoring it")
return
if self.isADuplicate(question):
log.warn("Duplicate question. ignoring it.")
return
questionObj = Question(self.curid,question,question)
self.addQuestion(questionObj)
self.curid += 1

def isADuplicate(self, question:str):
qlist = self.getQuestionList()
for q in qlist:
tq: Question = q
if tq.brief_description == question:
return True
return False

def clear(self):
"""Empties the questions dictionary."""
self.questions.clear()
Expand All @@ -34,6 +58,23 @@ def getQuestionList(self):
def getQuestion(self, question_id):
"""Returns the question for the given question_id, or None if not found."""
return self.questions.get(question_id)

@staticmethod
def getSampleRepo():
return _getSampleRepo()

def _getSampleRepo():
# Create Question instances
question1 = Question(1, "Brief description 1", "Full description 1")
question2 = Question(2, "Brief description 2", "Full description 2", "Answer 2")
question3 = Question(3, "Brief description 3", "Full description 3")

# Create a QuestionRepo instance and add questions
repo = QuestionRepo()
repo.addQuestion(question1)
repo.addQuestion(question2)
repo.addQuestion(question3)
return repo

def test_question_repo():
# Create Question instances
Expand Down
Loading

0 comments on commit 20bf0cc

Please sign in to comment.