Skip to content

Commit

Permalink
fix: shallow copy config of create_embedder
Browse files Browse the repository at this point in the history
  • Loading branch information
liaoliaojun committed Jun 15, 2024
1 parent c44b701 commit 62b372b
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions scrapegraphai/graphs/abstract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,40 +333,41 @@ def _create_embedder(self, embedder_config: dict) -> object:
Raises:
KeyError: If the model is not supported.
"""
embedder_params = {**embedder_config}
if "model_instance" in embedder_config:
return embedder_config["model_instance"]
return embedder_params["model_instance"]
# Instantiate the embedding model based on the model name
if "openai" in embedder_config["model"]:
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
elif "azure" in embedder_config["model"]:
if "openai" in embedder_params["model"]:
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
elif "azure" in embedder_params["model"]:
return AzureOpenAIEmbeddings()
elif "ollama" in embedder_config["model"]:
embedder_config["model"] = embedder_config["model"].split("ollama/")[-1]
elif "ollama" in embedder_params["model"]:
embedder_params["model"] = embedder_params["model"].split("ollama/")[-1]
try:
models_tokens["ollama"][embedder_config["model"]]
models_tokens["ollama"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OllamaEmbeddings(**embedder_config)
elif "hugging_face" in embedder_config["model"]:
return OllamaEmbeddings(**embedder_params)
elif "hugging_face" in embedder_params["model"]:
try:
models_tokens["hugging_face"][embedder_config["model"]]
models_tokens["hugging_face"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
elif "gemini" in embedder_config["model"]:
return HuggingFaceHubEmbeddings(model=embedder_params["model"])
elif "gemini" in embedder_params["model"]:
try:
models_tokens["gemini"][embedder_config["model"]]
models_tokens["gemini"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return GoogleGenerativeAIEmbeddings(model=embedder_config["model"])
elif "bedrock" in embedder_config["model"]:
embedder_config["model"] = embedder_config["model"].split("/")[-1]
client = embedder_config.get("client", None)
return GoogleGenerativeAIEmbeddings(model=embedder_params["model"])
elif "bedrock" in embedder_params["model"]:
embedder_params["model"] = embedder_params["model"].split("/")[-1]
client = embedder_params.get("client", None)
try:
models_tokens["bedrock"][embedder_config["model"]]
models_tokens["bedrock"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return BedrockEmbeddings(client=client, model_id=embedder_config["model"])
return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
else:
raise ValueError("Model provided by the configuration not supported")

Expand Down

0 comments on commit 62b372b

Please sign in to comment.