Skip to content

Commit

Permalink
Migrated oauth flow to db for google calendar (TransformerOptimus#663)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tarraann authored Jul 14, 2023
1 parent 5ce2dd6 commit 8433f97
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 111 deletions.
35 changes: 18 additions & 17 deletions gui/pages/Content/Toolkits/ToolkitWorkspace.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export default function ToolkitWorkspace({toolkitDetails, internalId}){
function getGoogleToken(client_data){
const client_id = client_data.client_id
const scope = 'https://www.googleapis.com/auth/calendar';
const redirect_uri = 'http://localhost:3000/api/oauth-calendar';
const redirect_uri = 'http://localhost:3000/api/google/oauth-tokens';
window.location.href = `https://accounts.google.com/o/oauth2/v2/auth?client_id=${client_id}&redirect_uri=${redirect_uri}&access_type=offline&response_type=code&scope=${scope}`;
}

Expand Down Expand Up @@ -68,24 +68,25 @@ export default function ToolkitWorkspace({toolkitDetails, internalId}){
};

const handleAuthenticateClick = async (toolkitName) => {
if(toolkitName === 'Google Calendar Toolkit') {
authenticateGoogleCred(toolkitDetails.id)
.then((response) => {
if (toolkitName === "Google Calendar Toolkit"){
authenticateGoogleCred(toolkitDetails.id)
.then((response) => {
localStorage.setItem("google_calendar_toolkit_id", toolkitDetails.id)
getGoogleToken(response.data);
})
.catch((error) => {
console.error('Error fetching data:', error);
});
} else if(toolkitName === 'Twitter Toolkit') {
})
.catch((error) => {
console.error('Error fetching data:', error);
});
}else if(toolkitName === "Twitter Toolkit"){
authenticateTwitterCred(toolkitDetails.id)
.then((response) => {
localStorage.setItem("twitter_toolkit_id", toolkitDetails.id)
getTwitterToken(response.data);
})
.catch((error) => {
console.error('Error fetching data: ', error);
});
}
.then((response) => {
localStorage.setItem("twitter_toolkit_id", toolkitDetails.id)
getTwitterToken(response.data);
})
.catch((error) => {
console.error('Error fetching data: ', error);
});
}
};

useEffect(() => {
Expand Down
37 changes: 25 additions & 12 deletions gui/pages/Dashboard/Content.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import AgentWorkspace from '../Content/Agents/AgentWorkspace';
import ToolkitWorkspace from '../Content/./Toolkits/ToolkitWorkspace';
import Toolkits from '../Content/./Toolkits/Toolkits';
import Settings from "./Settings/Settings";
import ApmDashboard from "../Content/APM/ApmDashboard";
import styles from './Dashboard.module.css';
import Image from "next/image";
import {EventBus} from "@/utils/eventBus";
import {getAgents, getLastActiveAgent, getToolKit, sendTwitterCreds} from "@/pages/api/DashboardService";
import { EventBus } from "@/utils/eventBus";
import {getAgents, getToolKit, getLastActiveAgent, sendTwitterCreds, sendGoogleCreds} from "@/pages/api/DashboardService";
import Market from "../Content/Marketplace/Market";
import AgentTemplatesList from '../Content/Agents/AgentTemplatesList';
import {useRouter} from 'next/router';
Expand Down Expand Up @@ -114,7 +113,7 @@ export default function Content({env, selectedView, selectedProjectId, organisat

const selectTab = (element, index) => {
setSelectedTab(index);
if (element.contentType === "Toolkits") {
if(element.contentType === "Toolkits") {
setToolkitDetails(element);
}
};
Expand All @@ -139,16 +138,30 @@ export default function Content({env, selectedView, selectedProjectId, organisat

if (window.location.href.indexOf("twitter_creds") > -1) {
parsedParams["toolkit_id"] = localStorage.getItem("twitter_toolkit_id") || null;
if (window.location.href.indexOf("twitter_creds") > -1){
const toolkit_id = localStorage.getItem("twitter_toolkit_id") || null;
parsedParams["toolkit_id"] = toolkit_id;
const params = JSON.stringify(parsedParams)
sendTwitterCreds(params)
.then((response) => {
console.log("Authentication completed successfully");
})
.catch((error) => {
console.error("Error fetching data: ", error);
})
}
;
.then((response) => {
console.log("Authentication completed successfully");
})
.catch((error) => {
console.error("Error fetching data: ",error);
})
};
if (window.location.href.indexOf("google_calendar_creds") > -1){
const toolkit_id = localStorage.getItem("google_calendar_toolkit_id") || null;
var data = Object.keys(parsedParams)[0];
var params = JSON.parse(data)
sendGoogleCreds(params, toolkit_id)
.then((response) => {
console.log("Authentication completed successfully");
})
.catch((error) => {
console.error("Error fetching data: ", error);
})
};
}, [selectedTab]);

useEffect(() => {
Expand Down
2 changes: 1 addition & 1 deletion gui/pages/_app.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export default function App() {
const [loadingText, setLoadingText] = useState("Initializing SuperAGI");
const router = useRouter();
const [showMarketplace, setShowMarketplace] = useState(false);
const excludedKeys = ['repo_starred', 'popup_closed_time', 'twitter_toolkit_id', 'accessToken', 'agent_to_install', 'toolkit_to_install', 'myLayoutKey'];
const excludedKeys = ['repo_starred', 'popup_closed_time', 'twitter_toolkit_id', 'accessToken', 'agent_to_install', 'toolkit_to_install', 'google_calendar_toolkit_id', 'myLayoutKey'];

function fetchOrganisation(userId) {
getOrganisation(userId)
Expand Down
4 changes: 4 additions & 0 deletions gui/pages/api/DashboardService.js
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ export const sendTwitterCreds = (twitter_creds) => {
return api.post(`/twitter/send_twitter_creds/${twitter_creds}`);
}

export const sendGoogleCreds = (google_creds, toolkit_id) => {
return api.post(`/google/send_google_creds/toolkit_id/${toolkit_id}`, google_creds);
}

export const fetchToolTemplateList = () => {
return api.get(`/toolkits/get/list?page=0`);
}
Expand Down
49 changes: 2 additions & 47 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from superagi.controllers.organisation import router as organisation_router
from superagi.controllers.project import router as project_router
from superagi.controllers.twitter_oauth import router as twitter_oauth_router
from superagi.controllers.google_oauth import router as google_oauth_router
from superagi.controllers.resources import router as resources_router
from superagi.controllers.tool import router as tool_router
from superagi.controllers.tool_config import router as tool_config_router
Expand Down Expand Up @@ -112,6 +113,7 @@
app.include_router(agent_execution_config, prefix="/agent_executions_configs")
app.include_router(analytics_router, prefix="/analytics")

app.include_router(google_oauth_router, prefix="/google")

# in production you can use Settings management
# from pydantic to get secret key from .env
Expand Down Expand Up @@ -340,45 +342,6 @@ def login(request: LoginRequest, Authorize: AuthJWT = Depends()):
# access_token = Authorize.create_access_token(subject=user_email)
# return access_token

@app.get('/oauth-calendar')
async def google_auth_calendar(code: str = Query(...), Authorize: AuthJWT = Depends()):
client_id = db.session.query(ToolConfig).filter(ToolConfig.key == "GOOGLE_CLIENT_ID").first()
client_id = client_id.value
client_secret = db.session.query(ToolConfig).filter(ToolConfig.key == "GOOGLE_CLIENT_SECRET").first()
client_secret = client_secret.value
token_uri = 'https://oauth2.googleapis.com/token'
scope = 'https://www.googleapis.com/auth/calendar'
params = {
'client_id': client_id,
'client_secret': client_secret,
'redirect_uri': "http://localhost:3000/api/oauth-calendar",
'scope': scope,
'grant_type': 'authorization_code',
'code': code,
'access_type': 'offline'
}
response = requests.post(token_uri, data=params)
response = response.json()
expire_time = datetime.utcnow() + timedelta(seconds=response['expires_in'])
expire_time = expire_time - timedelta(minutes=5)
response['expiry'] = expire_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
root_dir = superagi.config.config.get_config('RESOURCES_OUTPUT_ROOT_DIR')
file_name = "credential_token.pickle"
final_path = file_name
if root_dir is not None:
root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
final_path = root_dir + file_name
else:
final_path = os.getcwd() + "/" + file_name
try:
with open(final_path, mode="wb") as file:
pickle.dump(response, file)
except Exception as err:
return f"Error: {err}"
frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000")
return RedirectResponse(frontend_url)

@app.get('/github-login')
def github_login():
"""GitHub login"""
Expand Down Expand Up @@ -461,14 +424,6 @@ async def root(Authorize: AuthJWT = Depends()):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")


@app.get("/google/get_google_creds/toolkit_id/{toolkit_id}")
def get_google_calendar_tool_configs(toolkit_id: int):
google_calendar_config = db.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,
ToolConfig.key == "GOOGLE_CLIENT_ID").first()
return {
"client_id": google_calendar_config.value
}

@app.get("/validate-open-ai-key/{open_ai_key}")
async def root(open_ai_key: str, Authorize: AuthJWT = Depends()):
"""API to validate Open AI Key"""
Expand Down
72 changes: 72 additions & 0 deletions superagi/controllers/google_oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from fastapi import Depends, Query
from fastapi import APIRouter
from fastapi.responses import RedirectResponse
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from sqlalchemy.orm import sessionmaker

import superagi
import json
import requests
from datetime import datetime, timedelta
from superagi.models.db import connect_db
import http.client as http_client
from superagi.helper.auth import get_current_user
from superagi.models.tool_config import ToolConfig
from superagi.models.toolkit import Toolkit
from superagi.models.oauth_tokens import OauthTokens

router = APIRouter()

@router.get('/oauth-tokens')
async def google_auth_calendar(code: str = Query(...), Authorize: AuthJWT = Depends()):
client_id = db.session.query(ToolConfig).filter(ToolConfig.key == "GOOGLE_CLIENT_ID").first()
client_id = client_id.value
client_secret = db.session.query(ToolConfig).filter(ToolConfig.key == "GOOGLE_CLIENT_SECRET").first()
client_secret = client_secret.value
token_uri = 'https://oauth2.googleapis.com/token'
scope = 'https://www.googleapis.com/auth/calendar'
params = {
'client_id': client_id,
'client_secret': client_secret,
'redirect_uri': "http://localhost:3000/api/google/oauth-tokens",
'scope': scope,
'grant_type': 'authorization_code',
'code': code,
'access_type': 'offline'
}
response = requests.post(token_uri, data=params)
response = response.json()
expire_time = datetime.utcnow() + timedelta(seconds=response['expires_in'])
expire_time = expire_time - timedelta(minutes=5)
response['expiry'] = expire_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
response_data = json.dumps(response)
frontend_url = superagi.config.config.get_config("FRONTEND_URL", "http://localhost:3000")
redirect_url_success = f"{frontend_url}/google_calendar_creds/?{response_data}"
return RedirectResponse(url=redirect_url_success)

@router.post("/send_google_creds/toolkit_id/{toolkit_id}")
def send_google_calendar_configs(google_creds: dict, toolkit_id: int, Authorize: AuthJWT = Depends()):
engine = connect_db()
Session = sessionmaker(bind=engine)
session = Session()
current_user = get_current_user()
user_id = current_user.id
toolkit = db.session.query(Toolkit).filter(Toolkit.id == toolkit_id).first()
google_creds = json.dumps(google_creds)
print(google_creds)
tokens = OauthTokens().add_or_update(session, toolkit_id, user_id, toolkit.organisation_id, "GOOGLE_CALENDAR_OAUTH_TOKENS", google_creds)
if tokens:
success = True
else:
success = False
return success


@router.get("/get_google_creds/toolkit_id/{toolkit_id}")
def get_google_calendar_tool_configs(toolkit_id: int):
google_calendar_config = db.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id,
ToolConfig.key == "GOOGLE_CLIENT_ID").first()
return {
"client_id": google_calendar_config.value
}
51 changes: 27 additions & 24 deletions superagi/helper/google_calendar_creds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pickle
import os
import json
import ast
from datetime import datetime
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import Flow
Expand All @@ -9,32 +10,27 @@
from googleapiclient.discovery import build
from sqlalchemy.orm import sessionmaker
from superagi.models.db import connect_db
from sqlalchemy.orm import Session
from superagi.models.tool_config import ToolConfig
from superagi.resource_manager.file_manager import FileManager
from superagi.models.toolkit import Toolkit
from superagi.models.oauth_tokens import OauthTokens

class GoogleCalendarCreds:

def __init__(self, session: Session):
self.session = session

def get_credentials(self, toolkit_id):
file_name = "credential_token.pickle"
root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')
file_path = file_name
if root_dir is not None:
root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
file_path = root_dir + file_name
else:
file_path = os.getcwd() + "/" + file_name
if os.path.exists(file_path):
engine = connect_db()
Session = sessionmaker(bind=engine)
session = Session()
resource_manager: FileManager = None
with open(file_path,'rb') as file:
creds = pickle.load(file)
if isinstance(creds, str):
creds = json.loads(creds)
expire_time = datetime.strptime(creds["expiry"], "%Y-%m-%dT%H:%M:%S.%fZ")
google_creds = session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id).all()
toolkit = self.session.query(Toolkit).filter(Toolkit.id == toolkit_id).first()
organisation_id = toolkit.organisation_id
google_creds = self.session.query(OauthTokens).filter(OauthTokens.toolkit_id == toolkit_id, OauthTokens.organisation_id == organisation_id).first()
if google_creds:
user_id = google_creds.user_id
final_creds = json.loads(google_creds.value)
final_creds["refresh_token"] = self.fix_refresh_token(final_creds["refresh_token"])
expire_time = datetime.strptime(final_creds["expiry"], "%Y-%m-%dT%H:%M:%S.%fZ")
google_creds = self.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit_id).all()
client_id = ""
client_secret = ""
for credentials in google_creds:
Expand All @@ -46,15 +42,22 @@ def get_credentials(self, toolkit_id):
creds = Credentials.from_authorized_user_info(info={
"client_id": client_id,
"client_secret": client_secret,
"refresh_token": creds["refresh_token"],
"refresh_token": final_creds["refresh_token"],
"scopes": "https://www.googleapis.com/auth/calendar"
})
if expire_time < datetime.utcnow():
if expire_time > datetime.utcnow():
creds.refresh(Request())
creds_json = creds.to_json()
resource_manager.write_file(file_name, creds_json)
tokens = OauthTokens().add_or_update(self.session, toolkit_id, user_id, toolkit.organisation_id, "GOOGLE_CALENDAR_OAUTH_TOKENS", str(creds_json))
else:
return {"success": False}
service = build('calendar','v3',credentials=creds)
return {"success": True, "service": service}


def fix_refresh_token(self, refresh_token):
if refresh_token.count('/') == 1:
# Find the position of '/'
slash_index = refresh_token.index('/')
# Insert one more '/' at the position
refresh_token = refresh_token[:slash_index+1] + '/' + refresh_token[slash_index+1:]
return refresh_token
5 changes: 4 additions & 1 deletion superagi/tools/google_calendar/create_calendar_event.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any, Type
from pydantic import BaseModel, Field
from superagi.tools.base_tool import BaseTool
from sqlalchemy.orm import sessionmaker
from superagi.models.db import connect_db
from superagi.helper.google_calendar_creds import GoogleCalendarCreds
from superagi.helper.calendar_date import CalendarDate

Expand All @@ -20,8 +22,9 @@ class CreateEventCalendarTool(BaseTool):
description: str = "Create an event for Google Calendar"

def _execute(self, event_name: str, description: str, attendees: list, start_date: str = 'None', start_time: str = 'None', end_date: str = 'None', end_time: str = 'None', location: str = 'None'):
session = self.toolkit_config.session
toolkit_id = self.toolkit_config.toolkit_id
service = GoogleCalendarCreds().get_credentials(toolkit_id)
service = GoogleCalendarCreds(session).get_credentials(toolkit_id)
if service["success"]:
service = service["service"]
else:
Expand Down
Loading

0 comments on commit 8433f97

Please sign in to comment.