From c5b480e73912fb25c211683e396e964845b74d55 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Wed, 28 Sep 2022 00:23:40 +0200 Subject: [PATCH] happy path works --- src/App.tsx | 18 +++++--- src/ai/RoboflowObjectDetector.ts | 19 ++++++-- src/logic/actions/AIObjectDetectionActions.ts | 43 +++++++++++++------ src/logic/initializer/AppInitializer.ts | 24 +++++------ src/store/selectors/AISelector.ts | 4 ++ .../EditorTopNavigationBar.tsx | 7 ++- .../LoadRoboflowModelPopup.tsx | 30 ++++++++++--- 7 files changed, 101 insertions(+), 44 deletions(-) diff --git a/src/App.tsx b/src/App.tsx index b7705d05..d8bfd5c7 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -17,11 +17,14 @@ import NotificationsView from './views/NotificationsView/NotificationsView'; interface IProps { projectType: ProjectType; windowSize: ISize; - ObjectDetectorLoaded: boolean; - PoseDetectionLoaded: boolean; + objectDetectorLoaded: boolean; + poseDetectionLoaded: boolean; + roboflowJSObjectDetectorLoaded: boolean; } -const App: React.FC = ({projectType, windowSize, ObjectDetectorLoaded, PoseDetectionLoaded}) => { +const App: React.FC = ( + {projectType, windowSize, objectDetectorLoaded, poseDetectionLoaded, roboflowJSObjectDetectorLoaded} +) => { const selectRoute = () => { if (!!PlatformModel.mobileDeviceData.manufacturer && !!PlatformModel.mobileDeviceData.os) return ; @@ -36,8 +39,10 @@ const App: React.FC = ({projectType, windowSize, ObjectDetectorLoaded, P } }; + const isAILoaded = objectDetectorLoaded || poseDetectionLoaded || roboflowJSObjectDetectorLoaded + return ( -
{selectRoute()} @@ -50,8 +55,9 @@ const App: React.FC = ({projectType, windowSize, ObjectDetectorLoaded, P const mapStateToProps = (state: AppState) => ({ projectType: state.general.projectData.type, windowSize: state.general.windowSize, - ObjectDetectorLoaded: state.ai.isObjectDetectorLoaded, - PoseDetectionLoaded: state.ai.isPoseDetectorLoaded + objectDetectorLoaded: state.ai.isObjectDetectorLoaded, + poseDetectionLoaded: state.ai.isPoseDetectorLoaded, + roboflowJSObjectDetectorLoaded: state.ai.isRoboflowJSObjectDetectorLoaded }); export default connect( diff --git a/src/ai/RoboflowObjectDetector.ts b/src/ai/RoboflowObjectDetector.ts index 724f7e99..06b11e77 100644 --- a/src/ai/RoboflowObjectDetector.ts +++ b/src/ai/RoboflowObjectDetector.ts @@ -1,4 +1,5 @@ -import * as express from 'roboflow'; +// import * as express from 'roboflow'; +import {max} from "lodash"; export interface DetectedObject { @@ -36,9 +37,21 @@ export class RoboflowObjectDetector { RoboflowObjectDetector.model .detect(image) - .then((predictions: DetectedObject[]) => { + .then((predictions) => { + const processedPredictions: DetectedObject[] = predictions.map((raw) => { + return { + bbox: [ + max([raw.bbox.x - raw.bbox.width / 2, 0]), + max([raw.bbox.y - raw.bbox.height / 2, 0]), + raw.bbox.width, + raw.bbox.height + ], + class: raw.class, + score: raw.confidence + } + }) if (callback) { - callback(predictions) + callback(processedPredictions) } }) .catch((error) => { diff --git a/src/logic/actions/AIObjectDetectionActions.ts b/src/logic/actions/AIObjectDetectionActions.ts index 3cf83eb0..fc0b4361 100644 --- a/src/logic/actions/AIObjectDetectionActions.ts +++ b/src/logic/actions/AIObjectDetectionActions.ts @@ -13,6 +13,7 @@ import {PopupWindowType} from '../../data/enums/PopupWindowType'; import {updateActivePopupType} from '../../store/general/actionCreators'; import {AISelector} from '../../store/selectors/AISelector'; import {AIActions} from './AIActions'; +import {RoboflowObjectDetector} from '../../ai/RoboflowObjectDetector'; export class AIObjectDetectionActions { public static detectRectsForActiveImage(): void { @@ -21,22 +22,36 @@ export class AIObjectDetectionActions { } public static detectRects(imageId: string, image: HTMLImageElement): void { - if (LabelsSelector.getImageDataById(imageId).isVisitedByObjectDetector || !AISelector.isAIObjectDetectorModelLoaded()) + if (LabelsSelector.getImageDataById(imageId).isVisitedByObjectDetector) return; - store.dispatch(updateActivePopupType(PopupWindowType.LOADER)); - ObjectDetector.predict(image, (predictions: DetectedObject[]) => { - const suggestedLabelNames = AIObjectDetectionActions.extractNewSuggestedLabelNames(LabelsSelector.getLabelNames(), predictions); - const rejectedLabelNames = AISelector.getRejectedSuggestedLabelList(); - const newlySuggestedNames = AIActions.excludeRejectedLabelNames(suggestedLabelNames, rejectedLabelNames); - if (newlySuggestedNames.length > 0) { - store.dispatch(updateSuggestedLabelList(newlySuggestedNames)); - store.dispatch(updateActivePopupType(PopupWindowType.SUGGEST_LABEL_NAMES)); - } else { - store.dispatch(updateActivePopupType(null)); - } - AIObjectDetectionActions.saveRectPredictions(imageId, predictions); - }) + if (AISelector.isAIObjectDetectorModelLoaded()) { + store.dispatch(updateActivePopupType(PopupWindowType.LOADER)); + ObjectDetector.predict(image, (predictions: DetectedObject[]) => { + AIObjectDetectionActions.handleObjectDetectorResults(imageId, predictions); + }) + } + + if (AISelector.isAIRoboflowJSObjectDetectorModelLoaded()) { + store.dispatch(updateActivePopupType(PopupWindowType.LOADER)); + RoboflowObjectDetector.predict(image, (predictions: DetectedObject[]) => { + AIObjectDetectionActions.handleObjectDetectorResults(imageId, predictions); + }) + } + } + + private static handleObjectDetectorResults = (imageId: string, predictions: DetectedObject[]) => { + const suggestedLabelNames = AIObjectDetectionActions + .extractNewSuggestedLabelNames(LabelsSelector.getLabelNames(), predictions); + const rejectedLabelNames = AISelector.getRejectedSuggestedLabelList(); + const newlySuggestedNames = AIActions.excludeRejectedLabelNames(suggestedLabelNames, rejectedLabelNames); + if (newlySuggestedNames.length > 0) { + store.dispatch(updateSuggestedLabelList(newlySuggestedNames)); + store.dispatch(updateActivePopupType(PopupWindowType.SUGGEST_LABEL_NAMES)); + } else { + store.dispatch(updateActivePopupType(null)); + } + AIObjectDetectionActions.saveRectPredictions(imageId, predictions); } public static saveRectPredictions(imageId: string, predictions: DetectedObject[]) { diff --git a/src/logic/initializer/AppInitializer.ts b/src/logic/initializer/AppInitializer.ts index 9e75184f..9d214c30 100644 --- a/src/logic/initializer/AppInitializer.ts +++ b/src/logic/initializer/AppInitializer.ts @@ -1,11 +1,11 @@ -import {updateWindowSize} from "../../store/general/actionCreators"; -import {ContextManager} from "../context/ContextManager"; -import {store} from "../../index"; -import {PlatformUtil} from "../../utils/PlatformUtil"; -import {PlatformModel} from "../../staticModels/PlatformModel"; -import {EventType} from "../../data/enums/EventType"; -import {GeneralSelector} from "../../store/selectors/GeneralSelector"; -import {EnvironmentUtil} from "../../utils/EnvironmentUtil"; +import {updateWindowSize} from '../../store/general/actionCreators'; +import {ContextManager} from '../context/ContextManager'; +import {store} from '../../index'; +import {PlatformUtil} from '../../utils/PlatformUtil'; +import {PlatformModel} from '../../staticModels/PlatformModel'; +import {EventType} from '../../data/enums/EventType'; +import {GeneralSelector} from '../../store/selectors/GeneralSelector'; +import {EnvironmentUtil} from '../../utils/EnvironmentUtil'; export class AppInitializer { public static inti():void { @@ -37,11 +37,7 @@ export class AppInitializer { }; private static disableUnwantedKeyBoardBehaviour = (event: KeyboardEvent) => { - if (PlatformModel.isMac && event.metaKey) { - event.preventDefault(); - } - - if (["=", "+", "-"].includes(event.key)) { + if (['=', '+', '-'].includes(event.key)) { if (event.ctrlKey || (PlatformModel.isMac && event.metaKey)) { event.preventDefault(); } @@ -61,4 +57,4 @@ export class AppInitializer { PlatformModel.isSafari = PlatformUtil.isSafari(userAgent); PlatformModel.isFirefox = PlatformUtil.isFirefox(userAgent); }; -} \ No newline at end of file +} diff --git a/src/store/selectors/AISelector.ts b/src/store/selectors/AISelector.ts index bf3eec34..ccb12134 100644 --- a/src/store/selectors/AISelector.ts +++ b/src/store/selectors/AISelector.ts @@ -13,6 +13,10 @@ export class AISelector { return store.getState().ai.isObjectDetectorLoaded; } + public static isAIRoboflowJSObjectDetectorModelLoaded(): boolean { + return store.getState().ai.isRoboflowJSObjectDetectorLoaded; + } + public static isAIPoseDetectorModelLoaded(): boolean { return store.getState().ai.isPoseDetectorLoaded; } diff --git a/src/views/EditorView/EditorTopNavigationBar/EditorTopNavigationBar.tsx b/src/views/EditorView/EditorTopNavigationBar/EditorTopNavigationBar.tsx index a4468921..57a911cc 100644 --- a/src/views/EditorView/EditorTopNavigationBar/EditorTopNavigationBar.tsx +++ b/src/views/EditorView/EditorTopNavigationBar/EditorTopNavigationBar.tsx @@ -102,6 +102,10 @@ const EditorTopNavigationBar: React.FC = ( updateCrossHairVisibleStatusAction(!crossHairVisible); }; + const showAIButtons = (activeLabelType === LabelType.RECT && AISelector.isAIObjectDetectorModelLoaded()) || + (activeLabelType === LabelType.RECT && AISelector.isAIRoboflowJSObjectDetectorModelLoaded()) || + (activeLabelType === LabelType.POINT && AISelector.isAIPoseDetectorModelLoaded()) + return (
@@ -174,8 +178,7 @@ const EditorTopNavigationBar: React.FC = ( ) }
- {((activeLabelType === LabelType.RECT && AISelector.isAIObjectDetectorModelLoaded()) || - (activeLabelType === LabelType.POINT && AISelector.isAIPoseDetectorModelLoaded())) &&
+ {showAIButtons &&
{ getButtonWithTooltip( 'accept-all', diff --git a/src/views/PopupView/LoadRoboflowModelPopup/LoadRoboflowModelPopup.tsx b/src/views/PopupView/LoadRoboflowModelPopup/LoadRoboflowModelPopup.tsx index 8d3dd1fb..b02505c4 100644 --- a/src/views/PopupView/LoadRoboflowModelPopup/LoadRoboflowModelPopup.tsx +++ b/src/views/PopupView/LoadRoboflowModelPopup/LoadRoboflowModelPopup.tsx @@ -2,13 +2,19 @@ import React, {useState} from 'react'; import './LoadRoboflowModelPopup.scss'; import {GenericYesNoPopup} from '../GenericYesNoPopup/GenericYesNoPopup'; import {PopupActions} from '../../../logic/actions/PopupActions'; -import {updateActivePopupType} from '../../../store/general/actionCreators'; import {AppState} from '../../../store'; import {connect} from 'react-redux'; import {TextField} from '@mui/material'; import {styled} from '@mui/system'; import {Settings} from '../../../settings/Settings'; import {RoboflowObjectDetector} from '../../../ai/RoboflowObjectDetector'; +import {updateActiveLabelType} from '../../../store/labels/actionCreators'; +import {LabelsActionTypes} from '../../../store/labels/types'; +import {LabelType} from '../../../data/enums/LabelType'; +import {updateRoboflowJSObjectDetectorStatus} from '../../../store/ai/actionCreators'; +import {AIActionTypes} from '../../../store/ai/types'; +import {LabelsSelector} from '../../../store/selectors/LabelsSelector'; +import {AIObjectDetectionActions} from '../../../logic/actions/AIObjectDetectionActions'; const StyledTextField = styled(TextField)({ '& .MuiInputBase-root': { @@ -32,12 +38,25 @@ const StyledTextField = styled(TextField)({ }); -const LoadRoboflowModelPopup: React.FC = () => { +interface IProps { + updateActiveLabelTypeAction: (activeLabelType: LabelType) => LabelsActionTypes, + updateRoboflowJSObjectDetectorStatusAction: (isRoboflowJSObjectDetectorLoaded: boolean) => AIActionTypes +} + +const LoadRoboflowModelPopup: React.FC = ({ + updateActiveLabelTypeAction, updateRoboflowJSObjectDetectorStatusAction + }) => { const [publishableKey, setPublishableKey] = useState(''); const [modelId, setModelId] = useState(''); const [modelVersion, setModelVersion] = useState(1); const onModelLoadSuccess = () => { + updateRoboflowJSObjectDetectorStatusAction(true) + updateActiveLabelTypeAction(LabelType.RECT) + const activeLabelType: LabelType = LabelsSelector.getActiveLabelType(); + if (activeLabelType === LabelType.RECT) { + AIObjectDetectionActions.detectRectsForActiveImage(); + } PopupActions.close(); } @@ -69,7 +88,7 @@ const LoadRoboflowModelPopup: React.FC = () => { { { } const mapDispatchToProps = { - updateActivePopupTypeAction: updateActivePopupType, + updateRoboflowJSObjectDetectorStatusAction: updateRoboflowJSObjectDetectorStatus, + updateActiveLabelTypeAction: updateActiveLabelType, }; const mapStateToProps = (state: AppState) => ({});