Skip to content

Commit

Permalink
feat(pydantic): added pydantic output schema
Browse files Browse the repository at this point in the history
  • Loading branch information
PeriniM committed Jun 4, 2024
1 parent 1d217e4 commit 376f758
Show file tree
Hide file tree
Showing 23 changed files with 165 additions and 125 deletions.
63 changes: 63 additions & 0 deletions examples/openai/search_graph_schema_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
Example of Search Graph
"""

import os
from dotenv import load_dotenv
load_dotenv()

from scrapegraphai.graphs import SearchGraph
from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info

from pydantic import BaseModel, Field
from typing import List

# ************************************************
# Define the output schema for the graph
# ************************************************

# Schema for one extracted dish. Each field carries a human-readable
# description via pydantic's Field(description=...).
# NOTE: documented with comments rather than a class docstring on purpose --
# pydantic copies a model's docstring into its generated JSON schema, which
# would change the schema handed to the graph.
class Dish(BaseModel):
    name: str = Field(description="The name of the dish")
    description: str = Field(description="The description of the dish")

# Top-level output schema passed to SearchGraph below: a list of Dish entries
# under the single key "dishes".
# NOTE: comments are used instead of a class docstring so the JSON schema
# pydantic generates for this model is left unchanged.
class Dishes(BaseModel):
    dishes: List[Dish]

# ************************************************
# Define the configuration for the graph
# ************************************************

# API key is read from the OPENAI_APIKEY environment variable; the value is
# None when the variable is not set.
openai_key = os.environ.get("OPENAI_APIKEY")

# Configuration dictionary handed to SearchGraph.
graph_config = {
    # Language-model settings: credential and model name.
    "llm": {"api_key": openai_key, "model": "gpt-3.5-turbo"},
    # presumably the number of search results the graph will process --
    # TODO confirm against SearchGraph
    "max_results": 2,
    "verbose": True,
}

# ************************************************
# Create the SearchGraph instance and run it
# ************************************************

# Build the search graph for the prompt, passing the Dishes pydantic model
# as the requested output schema.
search_graph = SearchGraph(
    prompt="List me Chioggia's famous dishes",
    config=graph_config,
    schema=Dishes,
)

# Execute the graph and print whatever it returns.
result = search_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

# Fetch the graph's execution info and print its prettified form.
graph_exec_info = search_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))

# Save the result to csv, then json, both under the name "result".
for export in (convert_to_csv, convert_to_json):
    export(result, "result")
29 changes: 11 additions & 18 deletions examples/openai/smart_scraper_schema_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

import os, json
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from typing import List

from scrapegraphai.graphs import SmartScraperGraph

load_dotenv()
Expand All @@ -12,22 +15,12 @@
# Define the output schema for the graph
# ************************************************

schema= """
{
"Projects": [
"Project #":
{
"title": "...",
"description": "...",
},
"Project #":
{
"title": "...",
"description": "...",
}
]
}
"""
# Schema for one extracted project. Each field carries a human-readable
# description via pydantic's Field(description=...).
# NOTE: documented with comments rather than a class docstring on purpose --
# pydantic copies a model's docstring into its generated JSON schema.
class Project(BaseModel):
    title: str = Field(description="The title of the project")
    description: str = Field(description="The description of the project")

# Top-level output schema passed to SmartScraperGraph: a list of Project
# entries under the single key "projects".
# NOTE: comments are used instead of a class docstring so the JSON schema
# pydantic generates for this model is left unchanged.
class Projects(BaseModel):
    projects: List[Project]

# ************************************************
# Define the configuration for the graph
Expand All @@ -51,9 +44,9 @@
smart_scraper_graph = SmartScraperGraph(
prompt="List me all the projects with their description",
source="https://perinim.github.io/projects/",
schema=schema,
schema=Projects,
config=graph_config
)

result = smart_scraper_graph.run()
print(json.dumps(result, indent=4))
print(result)
5 changes: 3 additions & 2 deletions scrapegraphai/graphs/abstract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
"""

from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Union
import uuid
from pydantic import BaseModel

from langchain_aws import BedrockEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
Expand Down Expand Up @@ -62,7 +63,7 @@ class AbstractGraph(ABC):
"""

def __init__(self, prompt: str, config: dict,
source: Optional[str] = None, schema: Optional[str] = None):
source: Optional[str] = None, schema: Optional[BaseModel] = None):

self.prompt = prompt
self.source = source
Expand Down
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/csv_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
Expand All @@ -20,7 +21,7 @@ class CSVScraperGraph(AbstractGraph):
information from web pages using a natural language model to interpret and answer prompts.
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
"""
Initializes the CSVScraperGraph with a prompt, source, and configuration.
"""
Expand Down
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/deep_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
Expand Down Expand Up @@ -56,7 +57,7 @@ class DeepScraperGraph(AbstractGraph):
)
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):

super().__init__(prompt, config, source, schema)

Expand Down
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/json_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
Expand Down Expand Up @@ -44,7 +45,7 @@ class JSONScraperGraph(AbstractGraph):
>>> result = json_scraper.run()
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
super().__init__(prompt, config, source, schema)

self.input_key = "json" if source.endswith("json") else "json_dir"
Expand Down
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/omni_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
Expand Down Expand Up @@ -52,7 +53,7 @@ class OmniScraperGraph(AbstractGraph):
)
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):

self.max_images = 5 if config is None else config.get("max_images", 5)

Expand Down
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/omni_search_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from copy import copy, deepcopy
from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
Expand Down Expand Up @@ -43,7 +44,7 @@ class OmniSearchGraph(AbstractGraph):
>>> result = search_graph.run()
"""

def __init__(self, prompt: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):

self.max_results = config.get("max_results", 3)

Expand Down
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/pdf_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
Expand Down Expand Up @@ -46,7 +47,7 @@ class PDFScraperGraph(AbstractGraph):
>>> result = pdf_scraper.run()
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
super().__init__(prompt, config, source, schema)

self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"
Expand Down
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/script_creator_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
Expand Down Expand Up @@ -46,7 +47,7 @@ class ScriptCreatorGraph(AbstractGraph):
>>> result = script_creator.run()
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):

self.library = config['library']

Expand Down
8 changes: 6 additions & 2 deletions scrapegraphai/graphs/search_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from copy import copy, deepcopy
from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
Expand Down Expand Up @@ -42,14 +43,16 @@ class SearchGraph(AbstractGraph):
>>> result = search_graph.run()
"""

def __init__(self, prompt: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):

self.max_results = config.get("max_results", 3)

if all(isinstance(value, str) for value in config.values()):
self.copy_config = copy(config)
else:
self.copy_config = deepcopy(config)

self.copy_schema = deepcopy(schema)

super().__init__(prompt, config, schema)

Expand All @@ -68,7 +71,8 @@ def _create_graph(self) -> BaseGraph:
smart_scraper_instance = SmartScraperGraph(
prompt="",
source="",
config=self.copy_config
config=self.copy_config,
schema=self.copy_schema
)

# ************************************************
Expand Down
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/smart_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
Expand Down Expand Up @@ -48,7 +49,7 @@ class SmartScraperGraph(AbstractGraph):
)
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
super().__init__(prompt, config, source, schema)

self.input_key = "url" if source.startswith("http") else "local_dir"
Expand Down
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/smart_scraper_multi_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from copy import copy, deepcopy
from typing import List, Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
Expand Down Expand Up @@ -42,7 +43,7 @@ class SmartScraperMultiGraph(AbstractGraph):
>>> result = search_graph.run()
"""

def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None):

self.max_results = config.get("max_results", 3)

Expand Down
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/speech_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
Expand Down Expand Up @@ -47,7 +48,7 @@ class SpeechGraph(AbstractGraph):
... {"llm": {"model": "gpt-3.5-turbo"}}
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
super().__init__(prompt, config, source, schema)

self.input_key = "url" if source.startswith("http") else "local_dir"
Expand Down
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/xml_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
Expand Down Expand Up @@ -46,7 +47,7 @@ class XMLScraperGraph(AbstractGraph):
>>> result = xml_scraper.run()
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
super().__init__(prompt, config, source, schema)

self.input_key = "xml" if source.endswith("xml") else "xml_dir"
Expand Down
4 changes: 2 additions & 2 deletions scrapegraphai/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .schemas import graph_schema
from .models_tokens import models_tokens
from .robots import robots_dictionary
from .generate_answer_node_prompts import template_chunks, template_chunks_with_schema, template_no_chunks, template_no_chunks_with_schema, template_merge
from .generate_answer_node_prompts import template_chunks, template_no_chunks, template_merge
from .generate_answer_node_csv_prompts import template_chunks_csv, template_no_chunks_csv, template_merge_csv
from .generate_answer_node_pdf_prompts import template_chunks_pdf, template_no_chunks_pdf, template_merge_pdf, template_chunks_pdf_with_schema, template_no_chunks_pdf_with_schema
from .generate_answer_node_pdf_prompts import template_chunks_pdf, template_no_chunks_pdf, template_merge_pdf
from .generate_answer_node_omni_prompts import template_chunks_omni, template_no_chunk_omni, template_merge_omni
Loading

0 comments on commit 376f758

Please sign in to comment.