Skip to content

Commit

Permalink
Models marketplace updates (TransformerOptimus#1214)
Browse files Browse the repository at this point in the history
Model marketplace loader
  • Loading branch information
jedan2506 authored Sep 8, 2023
1 parent 4e0295d commit f38c4ae
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 9 deletions.
11 changes: 8 additions & 3 deletions gui/pages/Content/Models/MarketModels.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,23 @@ export default function MarketModels(){
if (window.location.href.toLowerCase().includes('marketplace')) {
axios.get('https://app.superagi.com/api/models_controller/get/models_details')
.then((response) => {
console.log(response.data)
setModelTemplates(response.data)
})
}
else {
fetchMarketPlaceModel().then((response) => {
console.log(response.data)
setModelTemplates(response.data)
})
}
},[])

useEffect(() => {
if(modelTemplates.length > 0)
setIsLoading(true)
else
setIsLoading(false)
}, [modelTemplates])

function handleTemplateClick(item) {
const contentType = 'model_template';
EventBus.emit('openTemplateDetails', {item, contentType});
Expand All @@ -38,7 +43,7 @@ export default function MarketModels(){
return(
<div id="market_models" className={showMarketplace ? 'ml_8' : 'ml_3'}>
<div className="w_100 overflowY_auto mxh_78vh">
{!isLoading ? <div>
{isLoading ? <div>
{modelTemplates.length > 0 ? <div className="marketplaceGrid">{modelTemplates.map((item) => (
<div className="market_containers cursor_pointer" key={item.id} onClick={() => handleTemplateClick(item)}>
<div>{item.model_name && item.model_name.includes('/') ? item.model_name.split('/')[1] : item.model_name}</div>
Expand Down
17 changes: 15 additions & 2 deletions gui/pages/Content/Models/ModelTemplate.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ export default function ModelTemplate({env, template}){
EventBus.emit('goToMarketplace', {});
}

function handleInstallClick() {
if (window.location.href.toLowerCase().includes('marketplace')) {
if (env === 'PROD') {
window.open(`https://app.superagi.com/`, '_self');
} else {
window.location.href = '/';
}
}
else {
setIsInstalled(true)
}
}

return (
<div id="model_template">
<div className="back_button mt_16 mb_16" onClick={() => isInstalled ? setIsInstalled(false) : handleBackClick()}>
Expand All @@ -20,7 +33,7 @@ export default function ModelTemplate({env, template}){
<div className="col_3 display_column_container padding_16">
<span className="text_20 color_white">{template.model_name}</span>
<span className="text_12 color_gray mt_4">by {template.model_name.includes('/') ? template.model_name.split('/')[0] : template.provider}</span>
<button className="primary_button w_100 mt_16" disabled={template.is_installed} onClick={() => setIsInstalled(true)}>
<button className="primary_button w_100 mt_16" disabled={template.is_installed} onClick={() => handleInstallClick()}>
<Image width={16} height={16} src={template.is_installed ? '/images/tick.svg' : '/images/marketplace_download.svg'} alt="download-icon" />
<span className="ml_8">{template.is_installed ? 'Installed' : 'Install'}</span>
</button>
Expand All @@ -39,7 +52,7 @@ export default function ModelTemplate({env, template}){
<span className="text_12 color_gray">Updated At</span>
<span className="text_12 color_white mt_8">{getFormattedDate(template.updated_at)}</span>
</div>
<div className="col_9 display_column_container padding_16 color_white" dangerouslySetInnerHTML={{ __html: template.model_features }} />
<div className="col_9 display_column_container padding_16 color_white text_12 lh_18" dangerouslySetInnerHTML={{ __html: template.model_features }} />
</div> ):(
<AddModelMarketPlace template={template} />
)}
Expand Down
4 changes: 3 additions & 1 deletion superagi/controllers/models_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from superagi.models.models import Models
from superagi.models.models_config import ModelsConfig
from superagi.config.config import get_config
from superagi.controllers.types.models_types import ModelsTypes
from fastapi_sqlalchemy import db
import logging
from pydantic import BaseModel
Expand Down Expand Up @@ -153,5 +154,6 @@ def get_models_details(page: int = 0):
if page < 0:
page = 0
marketplace_models = Models.fetch_marketplace_list(page)
marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id)
marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id,
ModelsTypes.MARKETPLACE.value)
return marketplace_models_with_install
14 changes: 14 additions & 0 deletions superagi/controllers/types/models_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from enum import Enum

class ModelsTypes(Enum):
MARKETPLACE = "Marketplace"
CUSTOM = "Custom"

@classmethod
def get_models_types(cls, model_type):
if model_type is None:
raise ValueError("Queue status type cannot be None.")
model_type = model_type.upper()
if model_type in cls.__members__:
return cls[model_type]
raise ValueError(f"{model_type} is not a valid storage name.")
9 changes: 6 additions & 3 deletions superagi/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sqlalchemy.sql import func
from typing import List, Dict, Union
from superagi.models.base_model import DBBaseModel
from superagi.llms.openai import OpenAi
from superagi.controllers.types.models_types import ModelsTypes
from superagi.helper.encyption_helper import decrypt_data
import requests, logging

Expand Down Expand Up @@ -64,7 +64,7 @@ def fetch_marketplace_list(cls, page):
return []

@classmethod
def get_model_install_details(cls, session, marketplace_models, organisation_id):
def get_model_install_details(cls, session, marketplace_models, organisation_id, type=ModelsTypes.CUSTOM.value):
from superagi.models.models_config import ModelsConfig
installed_models = session.query(Models).filter(Models.org_id == organisation_id).all()
model_counts_dict = dict(
Expand All @@ -74,7 +74,10 @@ def get_model_install_details(cls, session, marketplace_models, organisation_id)

for model in marketplace_models:
try:
model["is_installed"] = installed_models_dict.get(model["model_name"], False)
if type == ModelsTypes.MARKETPLACE.value:
model["is_installed"] = False
else:
model["is_installed"] = installed_models_dict.get(model["model_name"], False)
model["installs"] = model_counts_dict.get(model["model_name"], 0)
model["provider"] = session.query(ModelsConfig).filter(
ModelsConfig.id == model["model_provider_id"]).first().provider
Expand Down

0 comments on commit f38c4ae

Please sign in to comment.