From f38c4aef1986ef0e4069233af2ea1d9a925eb4e9 Mon Sep 17 00:00:00 2001 From: Kalki <97698934+jedan2506@users.noreply.github.com> Date: Fri, 8 Sep 2023 17:08:21 +0530 Subject: [PATCH] Models marketplace updates (#1214) Model marketplace loader --- gui/pages/Content/Models/MarketModels.js | 11 ++++++++--- gui/pages/Content/Models/ModelTemplate.js | 17 +++++++++++++++-- superagi/controllers/models_controller.py | 4 +++- superagi/controllers/types/models_types.py | 14 ++++++++++++++ superagi/models/models.py | 9 ++++++--- 5 files changed, 46 insertions(+), 9 deletions(-) create mode 100644 superagi/controllers/types/models_types.py diff --git a/gui/pages/Content/Models/MarketModels.js b/gui/pages/Content/Models/MarketModels.js index 578eed96b..383b48be7 100644 --- a/gui/pages/Content/Models/MarketModels.js +++ b/gui/pages/Content/Models/MarketModels.js @@ -18,18 +18,23 @@ export default function MarketModels(){ if (window.location.href.toLowerCase().includes('marketplace')) { axios.get('https://app.superagi.com/api/models_controller/get/models_details') .then((response) => { - console.log(response.data) setModelTemplates(response.data) }) } else { fetchMarketPlaceModel().then((response) => { - console.log(response.data) setModelTemplates(response.data) }) } },[]) + useEffect(() => { + if(modelTemplates.length > 0) + setIsLoading(true) + else + setIsLoading(false) + }, [modelTemplates]) + function handleTemplateClick(item) { const contentType = 'model_template'; EventBus.emit('openTemplateDetails', {item, contentType}); @@ -38,7 +43,7 @@ export default function MarketModels(){ return(
- {!isLoading ?
+ {isLoading ?
{modelTemplates.length > 0 ?
{modelTemplates.map((item) => (
handleTemplateClick(item)}>
{item.model_name && item.model_name.includes('/') ? item.model_name.split('/')[1] : item.model_name}
diff --git a/gui/pages/Content/Models/ModelTemplate.js b/gui/pages/Content/Models/ModelTemplate.js index d04f58cd5..8535beced 100644 --- a/gui/pages/Content/Models/ModelTemplate.js +++ b/gui/pages/Content/Models/ModelTemplate.js @@ -10,6 +10,19 @@ export default function ModelTemplate({env, template}){ EventBus.emit('goToMarketplace', {}); } + function handleInstallClick() { + if (window.location.href.toLowerCase().includes('marketplace')) { + if (env === 'PROD') { + window.open(`https://app.superagi.com/`, '_self'); + } else { + window.location.href = '/'; + } + } + else { + setIsInstalled(true) + } + } + return (
isInstalled ? setIsInstalled(false) : handleBackClick()}> @@ -20,7 +33,7 @@ export default function ModelTemplate({env, template}){
{template.model_name} by {template.model_name.includes('/') ? template.model_name.split('/')[0] : template.provider} - @@ -39,7 +52,7 @@ export default function ModelTemplate({env, template}){ Updated At {getFormattedDate(template.updated_at)}
-
+
):( )} diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py index 1bd2eeb9d..cda0dd828 100644 --- a/superagi/controllers/models_controller.py +++ b/superagi/controllers/models_controller.py @@ -5,6 +5,7 @@ from superagi.models.models import Models from superagi.models.models_config import ModelsConfig from superagi.config.config import get_config +from superagi.controllers.types.models_types import ModelsTypes from fastapi_sqlalchemy import db import logging from pydantic import BaseModel @@ -153,5 +154,6 @@ def get_models_details(page: int = 0): if page < 0: page = 0 marketplace_models = Models.fetch_marketplace_list(page) - marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id) + marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id, + ModelsTypes.MARKETPLACE.value) return marketplace_models_with_install \ No newline at end of file diff --git a/superagi/controllers/types/models_types.py b/superagi/controllers/types/models_types.py new file mode 100644 index 000000000..0a3b99ce1 --- /dev/null +++ b/superagi/controllers/types/models_types.py @@ -0,0 +1,14 @@ +from enum import Enum + +class ModelsTypes(Enum): + MARKETPLACE = "Marketplace" + CUSTOM = "Custom" + + @classmethod + def get_models_types(cls, model_type): + if model_type is None: + raise ValueError("Queue status type cannot be None.") + model_type = model_type.upper() + if model_type in cls.__members__: + return cls[model_type] + raise ValueError(f"{model_type} is not a valid storage name.") diff --git a/superagi/models/models.py b/superagi/models/models.py index ced271364..ccd5cdcf5 100644 --- a/superagi/models/models.py +++ b/superagi/models/models.py @@ -2,7 +2,7 @@ from sqlalchemy.sql import func from typing import List, Dict, Union from superagi.models.base_model import DBBaseModel -from superagi.llms.openai import OpenAi +from superagi.controllers.types.models_types import ModelsTypes from superagi.helper.encyption_helper import decrypt_data import requests, logging @@ -64,7 +64,7 @@ def fetch_marketplace_list(cls, page): return [] @classmethod - def get_model_install_details(cls, session, marketplace_models, organisation_id): + def get_model_install_details(cls, session, marketplace_models, organisation_id, type=ModelsTypes.CUSTOM.value): from superagi.models.models_config import ModelsConfig installed_models = session.query(Models).filter(Models.org_id == organisation_id).all() model_counts_dict = dict( @@ -74,7 +74,10 @@ def get_model_install_details(cls, session, marketplace_models, organisation_id) for model in marketplace_models: try: - model["is_installed"] = installed_models_dict.get(model["model_name"], False) + if type == ModelsTypes.MARKETPLACE.value: + model["is_installed"] = False + else: + model["is_installed"] = installed_models_dict.get(model["model_name"], False) model["installs"] = model_counts_dict.get(model["model_name"], 0) model["provider"] = session.query(ModelsConfig).filter( ModelsConfig.id == model["model_provider_id"]).first().provider