Commit a3fa9d9
code cleanup and removal of comments (Azure-Samples#27)
jgbradley1 authored Jun 27, 2024
1 parent 9f64847 commit a3fa9d9
Showing 4 changed files with 61 additions and 81 deletions.
backend/src/aks-batch-job-template.yaml (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@ kind: Job
 metadata:
   name: PLACEHOLDER
 spec:
-  ttlSecondsAfterFinished: 120
+  ttlSecondsAfterFinished: 0
   backoffLimit: 6
   template:
     metadata:
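Note on the one-line change above: ttlSecondsAfterFinished is the standard Kubernetes TTL-after-finished setting for Jobs, so lowering it from 120 to 0 makes a finished indexing Job (and its pods) eligible for automatic deletion as soon as it completes, rather than lingering for two minutes. Below is a minimal sketch of inspecting that field with the official Python kubernetes client; the helper and job name are illustrative and not part of this repository, and it assumes the "graphrag" namespace used elsewhere in this commit.

from typing import Optional

from kubernetes import client, config


def job_ttl_seconds(job_name: str, namespace: str = "graphrag") -> Optional[int]:
    # Load the local kubeconfig (use config.load_incluster_config() when running in-cluster).
    config.load_kube_config()
    batch = client.BatchV1Api()
    job = batch.read_namespaced_job(name=job_name, namespace=namespace)
    # 0 means the TTL controller may delete the Job as soon as it finishes;
    # None means the Job is never cleaned up automatically.
    return job.spec.ttl_seconds_after_finished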
backend/src/api/index.py (122 changes: 51 additions & 71 deletions)
@@ -102,6 +102,23 @@ async def setup_indexing_pipeline(
             detail=f"Data container '{storage_name}' does not exist.",
         )
 
+    # check for prompts
+    entity_extraction_prompt_content = (
+        entity_extraction_prompt.file.read().decode("utf-8")
+        if entity_extraction_prompt
+        else None
+    )
+    community_report_prompt_content = (
+        community_report_prompt.file.read().decode("utf-8")
+        if community_report_prompt
+        else None
+    )
+    summarize_descriptions_prompt_content = (
+        summarize_descriptions_prompt.file.read().decode("utf-8")
+        if summarize_descriptions_prompt
+        else None
+    )
+
     # check for existing index job
     # it is okay if job doesn't exist, but if it does,
     # it must not be scheduled or running
@@ -117,46 +134,29 @@ async def setup_indexing_pipeline(
     # if indexing job is in a failed state, delete the associated K8s job and pod to allow for a new job to be scheduled
     if PipelineJobState(existing_job.status) == PipelineJobState.FAILED:
         _delete_k8s_job(f"indexing-job-{sanitized_index_name}", "graphrag")
-
-        # reset the job to scheduled state
-        existing_job.status = PipelineJobState.SCHEDULED
-        existing_job.percent_complete = 0
-        existing_job.progress = ""
-        existing_job.all_workflows = existing_job.completed_workflows = (
-            existing_job.failed_workflows
+        # reset the pipeline job details
+        existing_job._status = PipelineJobState.SCHEDULED
+        existing_job._percent_complete = 0
+        existing_job._progress = ""
+        existing_job._all_workflows = existing_job._completed_workflows = (
+            existing_job._failed_workflows
         ) = []
-        existing_job.entity_extraction_prompt = None
-        existing_job.community_report_prompt = None
-        existing_job.summarize_descriptions_prompt = None
-
-    # create or update state in cosmos db
-    entity_extraction_prompt_content = (
-        entity_extraction_prompt.file.read().decode("utf-8")
-        if entity_extraction_prompt
-        else None
-    )
-    community_report_prompt_content = (
-        community_report_prompt.file.read().decode("utf-8")
-        if community_report_prompt
-        else None
-    )
-    summarize_descriptions_prompt_content = (
-        summarize_descriptions_prompt.file.read().decode("utf-8")
-        if summarize_descriptions_prompt
-        else None
-    )
-    print(f"ENTITY EXTRACTION PROMPT:\n{entity_extraction_prompt_content}")
-    print(f"COMMUNITY REPORT PROMPT:\n{community_report_prompt_content}")
-    print(f"SUMMARIZE DESCRIPTIONS PROMPT:\n{summarize_descriptions_prompt_content}")
-    pipelinejob.create_item(
-        id=sanitized_index_name,
-        index_name=sanitized_index_name,
-        storage_name=sanitized_storage_name,
-        entity_extraction_prompt=entity_extraction_prompt_content,
-        community_report_prompt=community_report_prompt_content,
-        summarize_descriptions_prompt=summarize_descriptions_prompt_content,
-        status=PipelineJobState.SCHEDULED,
-    )
+        existing_job._entity_extraction_prompt = entity_extraction_prompt_content
+        existing_job._community_report_prompt = community_report_prompt_content
+        existing_job._summarize_descriptions_prompt = (
+            summarize_descriptions_prompt_content
+        )
+        existing_job.update_db()
+    else:
+        pipelinejob.create_item(
+            id=sanitized_index_name,
+            index_name=sanitized_index_name,
+            storage_name=sanitized_storage_name,
+            entity_extraction_prompt=entity_extraction_prompt_content,
+            community_report_prompt=community_report_prompt_content,
+            summarize_descriptions_prompt=summarize_descriptions_prompt_content,
+            status=PipelineJobState.SCHEDULED,
+        )
 
     """
     At this point, we know:
@@ -167,7 +167,6 @@ async def setup_indexing_pipeline(
     # update or create new item in container-store in cosmosDB
     if not _blob_service_client.get_container_client(sanitized_index_name).exists():
         _blob_service_client.create_container(sanitized_index_name)
-
     container_store_client = get_database_container_client(
         database_name="graphrag", container_name="container-store"
     )
@@ -221,9 +220,7 @@ async def setup_indexing_pipeline(
     )
 
 
-async def _start_indexing_pipeline(
-    index_name: str
-):
+async def _start_indexing_pipeline(index_name: str):
     # get sanitized name
     sanitized_index_name = sanitize_name(index_name)
 
@@ -265,41 +262,29 @@ async def _start_indexing_pipeline(
     )
 
     # set prompts for entity extraction, community report, and summarize descriptions.
-    # an environment variable is set to the file path of the prompt
     if pipeline_job.entity_extraction_prompt:
         fname = "entity-extraction-prompt.txt"
         with open(fname, "w") as outfile:
             outfile.write(pipeline_job.entity_extraction_prompt)
-        os.environ["GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE"] = fname
-        # data["entity_extraction"]["prompt"] = fname
-    # else:
-    #     data["entity_extraction"]["prompt"] = None
+        data["entity_extraction"]["prompt"] = fname
+    else:
+        data.pop("entity_extraction")
     if pipeline_job.community_report_prompt:
         fname = "community-report-prompt.txt"
         with open(fname, "w") as outfile:
             outfile.write(pipeline_job.community_report_prompt)
-        os.environ["GRAPHRAG_COMMUNITY_REPORT_PROMPT_FILE"] = fname
-        # data["community_reports"]["prompt"] = fname
-    # else:
-    #     data["community_reports"]["prompt"] = None
+        data["community_reports"]["prompt"] = fname
+    else:
+        data.pop("community_reports")
     if pipeline_job.summarize_descriptions_prompt:
         fname = "summarize-descriptions-prompt.txt"
         with open(fname, "w") as outfile:
             outfile.write(pipeline_job.summarize_descriptions_prompt)
-        os.environ["GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE"] = fname
-        # data["summarize_descriptions"]["prompt"] = fname
-    # else:
-    #     data["summarize_descriptions"]["prompt"] = None
-
-    # set placeholder values to None if they have not been set
-    # if data["entity_extraction"]["prompt"] == "PLACEHOLDER":
-    #     data["entity_extraction"]["prompt"] = None
-    # if data["community_reports"]["prompt"] == "PLACEHOLDER":
-    #     data["community_reports"]["prompt"] = None
-    # if data["summarize_descriptions"]["prompt"] == "PLACEHOLDER":
-    #     data["summarize_descriptions"]["prompt"] = None
-
-    # generate the default pipeline from default parameters and override with custom settings
+        data["summarize_descriptions"]["prompt"] = fname
+    else:
+        data.pop("summarize_descriptions")
+
+    # generate the default pipeline and override with custom settings
     parameters = create_graphrag_config(data, ".")
     pipeline_config = create_pipeline_config(parameters, True)
 
@@ -316,11 +301,6 @@ async def _start_indexing_pipeline(
         PipelineJobWorkflowCallbacks(pipeline_job)
     )
 
-    # print("#################### PIPELINE JOB:")
-    # pprint(pipeline_job.dump_model())
-    print("#################### PIPELINE CONFIG:")
-    print(pipeline_config)
-
     # run the pipeline
     try:
         async for workflow_result in run_pipeline_with_config(
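To make the index.py changes above easier to follow, here is a minimal, self-contained sketch of the new prompt-override flow: any custom prompt attached to the pipeline job is written to a local file and wired into the matching section of the settings dict, and sections without a custom prompt are dropped so graphrag falls back to its built-in defaults. The names pipeline_job and data mirror the diff; loading the settings template here and the generic file naming are illustrative assumptions rather than the repository's exact code.

import yaml


def apply_prompt_overrides(pipeline_job, settings_path: str) -> dict:
    # Illustrative loader for the pipeline-settings template shown in the next file
    # (the real code already has it loaded into a dict named `data`).
    with open(settings_path) as f:
        data = yaml.safe_load(f)

    # Map each optional prompt carried by the job to the config section it overrides.
    overrides = {
        "entity_extraction": pipeline_job.entity_extraction_prompt,
        "community_reports": pipeline_job.community_report_prompt,
        "summarize_descriptions": pipeline_job.summarize_descriptions_prompt,
    }
    for section, prompt in overrides.items():
        if prompt:
            # Persist the custom prompt and point the config at the file.
            fname = f"{section}-prompt.txt"  # illustrative; the diff uses fixed file names
            with open(fname, "w") as outfile:
                outfile.write(prompt)
            data[section]["prompt"] = fname
        else:
            # No custom prompt: drop the section so graphrag uses its default prompt.
            data.pop(section, None)
    return data

Compared to the removed code, prompt files are no longer advertised through GRAPHRAG_*_PROMPT_FILE environment variables; the file path now goes directly into the dict handed to create_graphrag_config.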
backend/src/api/pipeline-settings.yaml (14 changes: 7 additions & 7 deletions)
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 
 # this yaml file serves as a configuration template for the graphrag indexing jobs
-# some values are hardcoded while others will be dynamically set
+# some values are hardcoded while others denoted by PLACEHOLDER will be dynamically set
 input:
   type: blob
   file_type: text
@@ -74,14 +74,14 @@ embeddings:
     url: $AI_SEARCH_URL
     audience: $AI_SEARCH_AUDIENCE
 
-# entity_extraction:
-#   prompt: PLACEHOLDER
+entity_extraction:
+  prompt: PLACEHOLDER
 
-# community_reports:
-#   prompt: PLACEHOLDER
+community_reports:
+  prompt: PLACEHOLDER
 
-# summarize_descriptions:
-#   prompt: PLACEHOLDER
+summarize_descriptions:
+  prompt: PLACEHOLDER
 
 snapshots:
   graphml: True
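Since these three sections are now active rather than commented out, PLACEHOLDER works as a sentinel that the backend replaces with a real prompt file path (or removes the section entirely, per the index.py change above) before the config reaches create_graphrag_config. A purely illustrative guard of that assumption, not code from the repository:

def assert_prompts_resolved(settings: dict) -> None:
    # Fail fast if a PLACEHOLDER sentinel survived substitution instead of
    # becoming a real prompt file path (or being removed from the config).
    for section in ("entity_extraction", "community_reports", "summarize_descriptions"):
        prompt = settings.get(section, {}).get("prompt")
        if prompt == "PLACEHOLDER":
            raise ValueError(
                f"{section}.prompt still contains PLACEHOLDER; set a file path or remove the section"
            )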
backend/src/reporting/reporter_singleton.py (4 changes: 2 additions & 2 deletions)
@@ -4,8 +4,7 @@
 import os
 from urllib.parse import urlparse
 
-import requests
-from datashaper import NoopWorkflowCallbacks, WorkflowCallbacks
+from datashaper import WorkflowCallbacks
 
 from src.reporting.load_reporter import load_pipeline_reporter_from_list
 from src.reporting.typing import Reporters
@@ -32,6 +31,7 @@ def get_instance(cls) -> WorkflowCallbacks:
             )
         return cls._instance
 
+
 def _is_valid_url(url: str) -> bool:
     try:
         result = urlparse(url)
