Skip to content

Commit

Permalink
feat: Enable end users to pass model instances of HuggingFaceHub
Browse files Browse the repository at this point in the history
  • Loading branch information
shkamboj1 committed May 4, 2024
1 parent 98dec36 commit 7599234
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 0 deletions.
63 changes: 63 additions & 0 deletions examples/huggingfacehub/smart_scraper_huggingfacehub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
Basic example of scraping pipeline using SmartScraper using Azure OpenAI Key
"""

import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings




## required environment variable in .env
#HUGGINGFACEHUB_API_TOKEN
load_dotenv()

HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
# ************************************************
# Initialize the model instances
# ************************************************

repo_id = "mistralai/Mistral-7B-Instruct-v0.2"

llm_model_instance = HuggingFaceEndpoint(
repo_id=repo_id, max_length=128, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN
)




embedder_model_instance = HuggingFaceInferenceAPIEmbeddings(
api_key=HUGGINGFACEHUB_API_TOKEN, model_name="sentence-transformers/all-MiniLM-l6-v2"
)

# ************************************************
# Create the SmartScraperGraph instance and run it
# ************************************************

graph_config = {
"llm": {"model_instance": llm_model_instance},
"embeddings": {"model_instance": embedder_model_instance}
}

smart_scraper_graph = SmartScraperGraph(
prompt="List me all the events, with the following fields: company_name, event_name, event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode, event_category, third_party_redirect, no_of_days, time_in_hours, hosted_or_attending, refreshments_type, registration_available, registration_link",
# also accepts a string with the already downloaded HTML code
source="https://www.hmhco.com/event",
config=graph_config
)

result = smart_scraper_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))


7 changes: 7 additions & 0 deletions scrapegraphai/graphs/abstract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ def _set_model_token(self, llm):
self.model_token = models_tokens["azure"][llm.model_name]
except KeyError:
raise KeyError("Model not supported")

elif 'HuggingFaceEndpoint' in str(type(llm)):
if 'mistral' in llm.repo_id:
try:
self.model_token = models_tokens['mistral'][llm.repo_id]
except KeyError:
raise KeyError("Model not supported")


def _create_llm(self, llm_config: dict, chat=False) -> object:
Expand Down
3 changes: 3 additions & 0 deletions scrapegraphai/helpers/models_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,8 @@
"mistral.mistral-large-2402-v1:0": 32768,
"cohere.embed-english-v3": 512,
"cohere.embed-multilingual-v3": 512
},
"mistral": {
"mistralai/Mistral-7B-Instruct-v0.2": 32000
}
}
4 changes: 4 additions & 0 deletions scrapegraphai/nodes/rag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OllamaEmbeddings
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings

from ..models import OpenAI, Ollama, AzureOpenAI, HuggingFace, Bedrock
from .base_node import BaseNode
Expand Down Expand Up @@ -95,6 +96,9 @@ def execute(self, state: dict) -> dict:
api_key=embedding_model.openai_api_key)
elif isinstance(embedding_model, AzureOpenAIEmbeddings):
embeddings = embedding_model
elif isinstance(embedding_model, HuggingFaceInferenceAPIEmbeddings):
embeddings = embedding_model

elif isinstance(embedding_model, AzureOpenAI):
embeddings = AzureOpenAIEmbeddings()
elif isinstance(embedding_model, Ollama):
Expand Down

0 comments on commit 7599234

Please sign in to comment.