From 48381f1f78b2056d885376de8b6ff247cf87a99d Mon Sep 17 00:00:00 2001
From: Eduard van Valkenburg
Date: Sat, 24 Jun 2023 00:01:08 +0200
Subject: [PATCH] PowerBI: catch outdated token (#6634)
This adds just a small tweak to catch the error that says the token is
expired rather then retrying.
---
langchain/tools/powerbi/tool.py | 10 ++++++++++
langchain/utilities/powerbi.py | 4 ++--
2 files changed, 12 insertions(+), 2 deletions(-)
diff --git a/langchain/tools/powerbi/tool.py b/langchain/tools/powerbi/tool.py
index 2714e57eff725..85b181f906e3d 100644
--- a/langchain/tools/powerbi/tool.py
+++ b/langchain/tools/powerbi/tool.py
@@ -96,6 +96,11 @@ def _run(
logger.info("Query: %s", query)
pbi_result = self.powerbi.run(command=query)
result, error = self._parse_output(pbi_result)
+ if error is not None and "TokenExpired" in error:
+ self.session_cache[
+ tool_input
+ ] = "Authentication token expired or invalid, please try reauthenticate."
+ return self.session_cache[tool_input]
iterations = kwargs.get("iterations", 0)
if error and iterations < self.max_iterations:
@@ -140,6 +145,11 @@ async def _arun(
logger.info("Query: %s", query)
pbi_result = await self.powerbi.arun(command=query)
result, error = self._parse_output(pbi_result)
+ if error is not None and "TokenExpired" in error:
+ self.session_cache[
+ tool_input
+ ] = "Authentication token expired or invalid, please try reauthenticate."
+ return self.session_cache[tool_input]
iterations = kwargs.get("iterations", 0)
if error and iterations < self.max_iterations:
diff --git a/langchain/utilities/powerbi.py b/langchain/utilities/powerbi.py
index 16a6f379ccc7d..36505a234478c 100644
--- a/langchain/utilities/powerbi.py
+++ b/langchain/utilities/powerbi.py
@@ -226,7 +226,7 @@ async def arun(self, command: str) -> Any:
json=self._create_json_content(command),
timeout=10,
) as response:
- response_json = await response.json()
+ response_json = await response.json(content_type=response.content_type)
return response_json
async with aiohttp.ClientSession() as session:
async with session.post(
@@ -235,7 +235,7 @@ async def arun(self, command: str) -> Any:
json=self._create_json_content(command),
timeout=10,
) as response:
- response_json = await response.json()
+ response_json = await response.json(content_type=response.content_type)
return response_json