feat: add csv scraper and xml scraper multi
1 parent fa9722d · commit b408655
Showing 5 changed files with 361 additions and 0 deletions.
@@ -0,0 +1,62 @@
"""
Basic example of scraping pipeline using CSVScraperMultiGraph from CSV documents
"""

import os
import pandas as pd
from scrapegraphai.graphs import CSVScraperMultiGraph
from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info

# ************************************************
# Read the CSV file
# ************************************************

FILE_NAME = "inputs/username.csv"
curr_dir = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(curr_dir, FILE_NAME)

text = pd.read_csv(file_path)

# ************************************************
# Define the configuration for the graph
# ************************************************

graph_config = {
    "llm": {
        "model": "ollama/llama3",
        "temperature": 0,
        "format": "json",  # Ollama needs the format to be specified explicitly
        # "model_tokens": 2000, # set context length arbitrarily
        "base_url": "http://localhost:11434",
    },
    "embeddings": {
        "model": "ollama/nomic-embed-text",
        "temperature": 0,
        "base_url": "http://localhost:11434",
    },
    "verbose": True,
}

# ************************************************
# Create the CSVScraperMultiGraph instance and run it
# ************************************************

csv_scraper_graph = CSVScraperMultiGraph(
    prompt="List me all the last names",
    source=[str(text), str(text)],  # pass the CSV content as strings (duplicated for the demo)
    config=graph_config
)

result = csv_scraper_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = csv_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))

# Save to json or csv
convert_to_csv(result, "result")
convert_to_json(result, "result")
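The example assumes an inputs/username.csv next to the script, which is not part of this diff. A minimal stand-in can be generated with the sketch below; the column names and rows are hypothetical, and any CSV with a last-name column would do.

import os

# Hypothetical sample data: the real inputs/username.csv in the repository
# may use entirely different columns.
SAMPLE_CSV = """first_name,last_name
Rachel,Booker
Laura,Grey
Craig,Johnson
"""

os.makedirs("inputs", exist_ok=True)
with open(os.path.join("inputs", "username.csv"), "w", encoding="utf-8") as f:
    f.write(SAMPLE_CSV)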
@@ -0,0 +1,64 @@
"""
Basic example of scraping pipeline using XMLScraperMultiGraph from XML documents
"""

import os
from dotenv import load_dotenv
from scrapegraphai.graphs import XMLScraperMultiGraph
from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info

load_dotenv()

# ************************************************
# Read the XML file
# ************************************************

FILE_NAME = "inputs/books.xml"
curr_dir = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(curr_dir, FILE_NAME)

with open(file_path, 'r', encoding="utf-8") as file:
    text = file.read()

# ************************************************
# Define the configuration for the graph
# ************************************************

graph_config = {
    "llm": {
        "model": "ollama/llama3",
        "temperature": 0,
        "format": "json",  # Ollama needs the format to be specified explicitly
        # "model_tokens": 2000, # set context length arbitrarily
        "base_url": "http://localhost:11434",
    },
    "embeddings": {
        "model": "ollama/nomic-embed-text",
        "temperature": 0,
        "base_url": "http://localhost:11434",
    },
    "verbose": True,
}

# ************************************************
# Create the XMLScraperMultiGraph instance and run it
# ************************************************

xml_scraper_graph = XMLScraperMultiGraph(
    prompt="List me all the authors, title and genres of the books",
    source=[text, text],  # Pass the content of the file, not the file object
    config=graph_config
)

result = xml_scraper_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = xml_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))

# Save to json or csv
convert_to_csv(result, "result")
convert_to_json(result, "result")
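Likewise, inputs/books.xml is referenced but not included in the diff. A hypothetical minimal file with author, title, and genre fields, matching the prompt, can be created like this:

import os

# Hypothetical sample data: the books.xml shipped with the repository may differ.
SAMPLE_XML = """<?xml version="1.0"?>
<catalog>
  <book id="bk101">
    <author>Gambardella, Matthew</author>
    <title>XML Developer's Guide</title>
    <genre>Computer</genre>
  </book>
  <book id="bk102">
    <author>Ralls, Kim</author>
    <title>Midnight Rain</title>
    <genre>Fantasy</genre>
  </book>
</catalog>
"""

os.makedirs("inputs", exist_ok=True)
with open(os.path.join("inputs", "books.xml"), "w", encoding="utf-8") as f:
    f.write(SAMPLE_XML)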
@@ -0,0 +1,116 @@
"""
CSVScraperMultiGraph Module
"""

from copy import copy, deepcopy
from typing import List, Optional

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from .csv_scraper_graph import CSVScraperGraph

from ..nodes import (
    GraphIteratorNode,
    MergeAnswersNode
)


class CSVScraperMultiGraph(AbstractGraph):
    """
    CSVScraperMultiGraph is a scraping pipeline that scrapes a list of CSV documents and
    generates answers to a given prompt.
    It only requires a user prompt and a list of sources.

    Attributes:
        prompt (str): The user prompt to answer using the documents.
        llm_model (dict): The configuration for the language model.
        embedder_model (dict): The configuration for the embedder model.
        headless (bool): A flag to run the browser in headless mode.
        verbose (bool): A flag to display the execution information.
        model_token (int): The token limit for the language model.

    Args:
        prompt (str): The user prompt to answer using the documents.
        source (List[str]): The list of CSV sources for the graph.
        config (dict): Configuration parameters for the graph.
        schema (Optional[str]): The schema for the graph output.

    Example:
        >>> csv_scraper_graph = CSVScraperMultiGraph(
        ...     "List me all the last names",
        ...     [csv_content_1, csv_content_2],
        ...     {"llm": {"model": "ollama/llama3"}}
        ... )
        >>> result = csv_scraper_graph.run()
    """

    def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[str] = None):

        self.max_results = config.get("max_results", 3)

        # copy.copy suffices when the config holds only flat string values;
        # nested values are deep-copied so each sub-graph gets its own config.
        if all(isinstance(value, str) for value in config.values()):
            self.copy_config = copy(config)
        else:
            self.copy_config = deepcopy(config)

        super().__init__(prompt, config, source, schema)

    def _create_graph(self) -> BaseGraph:
        """
        Creates the graph of nodes representing the workflow for scraping the documents.

        Returns:
            BaseGraph: A graph instance representing the scraping workflow.
        """

        # ************************************************
        # Create a CSVScraperGraph instance
        # ************************************************

        smart_scraper_instance = CSVScraperGraph(
            prompt="",
            source="",
            config=self.copy_config,
        )

        # ************************************************
        # Define the graph nodes
        # ************************************************

        graph_iterator_node = GraphIteratorNode(
            input="user_prompt & jsons",
            output=["results"],
            node_config={
                "graph_instance": smart_scraper_instance,
            }
        )

        merge_answers_node = MergeAnswersNode(
            input="user_prompt & results",
            output=["answer"],
            node_config={
                "llm_model": self.llm_model,
                "schema": self.schema
            }
        )

        return BaseGraph(
            nodes=[
                graph_iterator_node,
                merge_answers_node,
            ],
            edges=[
                (graph_iterator_node, merge_answers_node),
            ],
            entry_point=graph_iterator_node
        )

    def run(self) -> str:
        """
        Executes the scraping process and answers the prompt.

        Returns:
            str: The answer to the prompt.
        """
        inputs = {"user_prompt": self.prompt, "jsons": self.source}
        self.final_state, self.execution_info = self.graph.execute(inputs)

        return self.final_state.get("answer", "No answer found.")
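The copy-versus-deepcopy branch in __init__ exists because copy.copy duplicates only the top-level dict: nested values such as the llm block would stay shared between the outer graph and its sub-graphs, so a mutation in one would leak into the other. A standalone illustration, not part of the commit:

from copy import copy, deepcopy

config = {"llm": {"model": "ollama/llama3"}, "verbose": True}

shallow = copy(config)
shallow["llm"]["model"] = "changed"
print(config["llm"]["model"])   # "changed" -- the nested dict is shared

config = {"llm": {"model": "ollama/llama3"}, "verbose": True}
deep = deepcopy(config)
deep["llm"]["model"] = "changed"
print(config["llm"]["model"])   # "ollama/llama3" -- fully independent copy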
@@ -0,0 +1,117 @@
"""
XMLScraperMultiGraph Module
"""

from copy import copy, deepcopy
from typing import List, Optional

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from .xml_scraper_graph import XMLScraperGraph

from ..nodes import (
    GraphIteratorNode,
    MergeAnswersNode
)


class XMLScraperMultiGraph(AbstractGraph):
    """
    XMLScraperMultiGraph is a scraping pipeline that scrapes a list of XML documents and
    generates answers to a given prompt.
    It only requires a user prompt and a list of sources.

    Attributes:
        prompt (str): The user prompt to answer using the documents.
        llm_model (dict): The configuration for the language model.
        embedder_model (dict): The configuration for the embedder model.
        headless (bool): A flag to run the browser in headless mode.
        verbose (bool): A flag to display the execution information.
        model_token (int): The token limit for the language model.

    Args:
        prompt (str): The user prompt to answer using the documents.
        source (List[str]): The list of XML sources for the graph.
        config (dict): Configuration parameters for the graph.
        schema (Optional[str]): The schema for the graph output.

    Example:
        >>> xml_scraper_graph = XMLScraperMultiGraph(
        ...     "List me all the authors, title and genres of the books",
        ...     [xml_content_1, xml_content_2],
        ...     {"llm": {"model": "ollama/llama3"}}
        ... )
        >>> result = xml_scraper_graph.run()
    """

    def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[str] = None):

        self.max_results = config.get("max_results", 3)

        # copy.copy suffices when the config holds only flat string values;
        # nested values are deep-copied so each sub-graph gets its own config.
        if all(isinstance(value, str) for value in config.values()):
            self.copy_config = copy(config)
        else:
            self.copy_config = deepcopy(config)

        super().__init__(prompt, config, source, schema)

    def _create_graph(self) -> BaseGraph:
        """
        Creates the graph of nodes representing the workflow for scraping the documents.

        Returns:
            BaseGraph: A graph instance representing the scraping workflow.
        """

        # ************************************************
        # Create an XMLScraperGraph instance
        # ************************************************

        smart_scraper_instance = XMLScraperGraph(
            prompt="",
            source="",
            config=self.copy_config,
        )

        # ************************************************
        # Define the graph nodes
        # ************************************************

        graph_iterator_node = GraphIteratorNode(
            input="user_prompt & jsons",
            output=["results"],
            node_config={
                "graph_instance": smart_scraper_instance,
            }
        )

        merge_answers_node = MergeAnswersNode(
            input="user_prompt & results",
            output=["answer"],
            node_config={
                "llm_model": self.llm_model,
                "schema": self.schema
            }
        )

        return BaseGraph(
            nodes=[
                graph_iterator_node,
                merge_answers_node,
            ],
            edges=[
                (graph_iterator_node, merge_answers_node),
            ],
            entry_point=graph_iterator_node
        )

    def run(self) -> str:
        """
        Executes the scraping process and answers the prompt.

        Returns:
            str: The answer to the prompt.
        """
        inputs = {"user_prompt": self.prompt, "jsons": self.source}
        self.final_state, self.execution_info = self.graph.execute(inputs)

        return self.final_state.get("answer", "No answer found.")
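Both multi-graphs share the same shape: GraphIteratorNode fans the prompt out over every source using a fresh copy of the single-source graph, and MergeAnswersNode folds the partial results into one answer. A rough, library-free sketch of that control flow (not the actual node implementations) is:

from typing import Callable, List

def run_multi(prompt: str, sources: List[str],
              scrape_one: Callable[[str, str], dict],
              merge: Callable[[str, List[dict]], str]) -> str:
    """Hypothetical sketch of the iterate-then-merge pipeline."""
    # GraphIteratorNode: run the inner graph once per source
    results = [scrape_one(prompt, source) for source in sources]
    # MergeAnswersNode: combine the per-source answers via the LLM
    return merge(prompt, results)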