Skip to content

Commit

Permalink
update bing_search.py
Browse files Browse the repository at this point in the history
  • Loading branch information
imClumsyPanda committed May 21, 2023
1 parent f986b75 commit 9c422cc
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 22 deletions.
21 changes: 10 additions & 11 deletions agent/bing_search.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
#coding=utf8

import os
from langchain.utilities import BingSearchAPIWrapper
from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY


env_bing_key = os.environ.get("BING_SUBSCRIPTION_KEY")
env_bing_url = os.environ.get("BING_SEARCH_URL")


def search(text, result_len=3):
if not (env_bing_key and env_bing_url):
return [{"snippet":"please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
"title": "env inof not fould", "link":"https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
search = BingSearchAPIWrapper()
def bing_search(text, result_len=3):
    """Run a Bing web search for ``text`` through LangChain's wrapper.

    ``result_len`` is forwarded as the second positional argument of
    ``BingSearchAPIWrapper.results``.

    If either ``BING_SEARCH_URL`` or ``BING_SUBSCRIPTION_KEY`` is falsy, no
    wrapper is created; a one-element list holding a placeholder entry with
    the keys ``snippet``, ``title`` and ``link`` is returned instead, so the
    caller still receives data in the same list-of-dicts form.
    """
    credentials_ready = bool(BING_SEARCH_URL and BING_SUBSCRIPTION_KEY)
    if credentials_ready:
        wrapper = BingSearchAPIWrapper(
            bing_subscription_key=BING_SUBSCRIPTION_KEY,
            bing_search_url=BING_SEARCH_URL,
        )
        return wrapper.results(text, result_len)

    # NOTE(review): the title below is misspelled ("inof", "fould"). It is
    # kept verbatim here because it is runtime text that callers may display
    # or compare against; correct it in a dedicated change if desired.
    placeholder = {
        "snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
        "title": "env inof not fould",
        "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html",
    }
    return [placeholder]


if __name__ == "__main__":
r = search('python')
r = bing_search('python')
print(r)
53 changes: 42 additions & 11 deletions chains/local_doc_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from configs.model_config import *
import datetime
from textsplitter import ChineseTextSplitter
from typing import List, Tuple
from typing import List, Tuple, Dict
from langchain.docstore.document import Document
import numpy as np
from utils import torch_gc
Expand All @@ -18,6 +18,8 @@
from models.loader.args import parser
from models.loader import LoaderCheckPoint
import models.shared as shared
from agent import bing_search
from langchain.docstore.document import Document


def load_file(filepath, sentence_size=SENTENCE_SIZE):
Expand Down Expand Up @@ -58,8 +60,9 @@ def write_check_file(filepath, docs):
fout.close()


def generate_prompt(related_docs: List[str], query: str,
prompt_template=PROMPT_TEMPLATE) -> str:
def generate_prompt(related_docs: List["Document"],
                    query: str,
                    prompt_template: str = PROMPT_TEMPLATE, ) -> str:
    """Fill a prompt template with the user's query and the retrieved context.

    Args:
        related_docs: Documents whose ``page_content`` attributes are joined
            with newlines to form the context text.
        query: The user's question; substituted for every ``{question}``
            placeholder in the template.
        prompt_template: Template text containing the literal placeholders
            ``{question}`` and ``{context}``.

    Returns:
        The template with both placeholders replaced.
    """
    import re

    context = "\n".join(doc.page_content for doc in related_docs)
    fillers = {"{question}": query, "{context}": context}
    # Substitute both placeholders in ONE pass over the template. Chaining
    # str.replace() calls re-scans text that was just inserted, so a query
    # containing the literal "{context}" would have the retrieved context
    # spliced into the middle of the question. A callable replacement also
    # keeps backslashes in the inserted text from being treated as escapes.
    return re.sub(r"\{question\}|\{context\}",
                  lambda match: fillers[match.group(0)],
                  prompt_template)
Expand Down Expand Up @@ -137,6 +140,16 @@ def similarity_search_with_score_by_vector(
return docs


def search_result2docs(search_results):
    """Convert search-result dicts into a list of ``Document`` objects.

    For each entry, ``snippet`` becomes the document's ``page_content``,
    ``link`` is stored as ``metadata["source"]`` and ``title`` as
    ``metadata["filename"]``. Any key missing from an entry is replaced by
    an empty string. The order of the input is preserved.
    """
    return [
        Document(
            page_content=entry.get("snippet", ""),
            metadata={
                "source": entry.get("link", ""),
                "filename": entry.get("title", ""),
            },
        )
        for entry in search_results
    ]


class LocalDocQA:
llm: BaseAnswer = None
embeddings: object = None
Expand Down Expand Up @@ -262,7 +275,6 @@ def get_knowledge_based_answer(self, query, vs_path, chat_history=[], streaming:
"source_documents": related_docs_with_score}
yield response, history


# query 查询内容
# vs_path 知识库路径
# chunk_conent 是否启用上下文关联
Expand All @@ -288,11 +300,26 @@ def get_knowledge_based_conent_test(self, query, vs_path, chunk_conent,
"source_documents": related_docs_with_score}
return response, prompt

def get_search_result_based_answer(self, query, chat_history=None, streaming: bool = STREAMING):
    """Answer ``query`` using web-search results as the prompt context.

    The query is sent to ``bing_search``, the results are converted to
    documents with ``search_result2docs``, and a prompt is built with
    ``generate_prompt``. That prompt is then passed to
    ``self.llm.generatorAnswer``.

    Args:
        query: The user's question.
        chat_history: Prior conversation passed to the LLM as ``history``.
            ``None`` (the default) means an empty history.
        streaming: Forwarded unchanged to ``self.llm.generatorAnswer``.

    Yields:
        ``(response, history)`` for every item produced by the LLM, where
        ``response`` is a dict with the keys ``query``, ``result`` and
        ``source_documents``.
    """
    # A literal [] default would be created once and shared by every call
    # that omits chat_history; because the list is handed to the LLM as its
    # history, one conversation could leak into the next.
    if chat_history is None:
        chat_history = []

    # NOTE(review): this name comes from `from agent import bing_search`.
    # Confirm that the import binds the function and not the submodule
    # agent/bing_search.py; if it is the module, this call raises TypeError.
    results = bing_search(query)
    result_docs = search_result2docs(results)
    prompt = generate_prompt(result_docs, query)

    for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
                                                  streaming=streaming):
        resp = answer_result.llm_output["answer"]
        history = answer_result.history
        # Store the user's original query as the question of the latest turn
        # (presumably replacing the expanded prompt that was sent to the
        # model; verify against generatorAnswer).
        history[-1][0] = query
        response = {"query": query,
                    "result": resp,
                    "source_documents": result_docs}
        yield response, history


if __name__ == "__main__":
# 初始化消息
args = None
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])

args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
Expand All @@ -304,13 +331,17 @@ def get_knowledge_based_conent_test(self, query, vs_path, chunk_conent,
query = "本项目使用的embedding模型是什么,消耗多少显存"
vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test"
last_print_len = 0
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
vs_path=vs_path,
chat_history=[],
streaming=True):
logger.info(resp["result"][last_print_len:], end="", flush=True)
# for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
# vs_path=vs_path,
# chat_history=[],
# streaming=True):
for resp, history in local_doc_qa.get_search_result_based_answer(query=query,
chat_history=[],
streaming=True):
print(resp["result"][last_print_len:], end="", flush=True)
last_print_len = len(resp["result"])
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}\n\n{doc.page_content}\n\n"""
source_text = [f"""出处 [{inum + 1}] {doc.metadata['source'] if doc.metadata['source'].startswith("http")
else os.path.split(doc.metadata['source'])[-1]}\n\n{doc.page_content}\n\n"""
# f"""相关度:{doc.metadata['score']}\n\n"""
for inum, doc in
enumerate(resp["source_documents"])]
Expand Down

0 comments on commit 9c422cc

Please sign in to comment.