Skip to content

Commit

Permalink
Better support for image generation capable models (onyx-dot-app#2725)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weves authored Oct 8, 2024
1 parent aa69fe7 commit 21a3921
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 37 deletions.
20 changes: 7 additions & 13 deletions web/src/app/admin/assistants/AssistantEditor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import { getDisplayNameForModel } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { Option } from "@/components/Dropdown";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
import { checkLLMSupportsImageOutput, destructureValue } from "@/lib/llm/utils";
import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils";
import { ToolSnapshot } from "@/lib/tools/interfaces";
import { checkUserIsNoAuthUser } from "@/lib/user";

Expand Down Expand Up @@ -349,12 +349,9 @@ export function AssistantEditor({

if (imageGenerationToolEnabled) {
if (
!checkLLMSupportsImageOutput(
providerDisplayNameToProviderName.get(
values.llm_model_provider_override || ""
) ||
defaultProviderName ||
"",
// model must support image input for image generation
// to work
!checkLLMSupportsImageInput(
values.llm_model_version_override || defaultModelName || ""
)
) {
Expand Down Expand Up @@ -469,12 +466,9 @@ export function AssistantEditor({
: false;
}

const currentLLMSupportsImageOutput = checkLLMSupportsImageOutput(
providerDisplayNameToProviderName.get(
values.llm_model_provider_override || ""
) ||
defaultProviderName ||
"",
// model must support image input for image generation
// to work
const currentLLMSupportsImageOutput = checkLLMSupportsImageInput(
values.llm_model_version_override || defaultModelName || ""
);

Expand Down
45 changes: 21 additions & 24 deletions web/src/lib/llm/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,45 +68,42 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
"gpt-4-vision-preview",
"gpt-4-turbo",
"gpt-4-1106-vision-preview",
"gpt-4o",
"gpt-4o-mini",
"gpt-4-vision-preview",
"gpt-4-turbo",
"gpt-4-1106-vision-preview",
// standard claude names
"claude-3-5-sonnet-20240620",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
// claude names with AWS Bedrock Suffix
"claude-3-opus-20240229-v1:0",
"claude-3-sonnet-20240229-v1:0",
"claude-3-haiku-20240307-v1:0",
"claude-3-5-sonnet-20240620-v1:0",
// claude names with full AWS Bedrock names
"anthropic.claude-3-opus-20240229-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
];

export function checkLLMSupportsImageInput(model: string) {
return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some(
// Original exact match check
const exactMatch = MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some(
(modelName) => modelName === model
);
}

const MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT = [
["openai", "gpt-4o"],
["openai", "gpt-4o-mini"],
["openai", "gpt-4-vision-preview"],
["openai", "gpt-4-turbo"],
["openai", "gpt-4-1106-vision-preview"],
["azure", "gpt-4o"],
["azure", "gpt-4o-mini"],
["azure", "gpt-4-vision-preview"],
["azure", "gpt-4-turbo"],
["azure", "gpt-4-1106-vision-preview"],
];
if (exactMatch) {
return true;
}

export function checkLLMSupportsImageOutput(provider: string, model: string) {
return MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT.some(
(modelProvider) =>
modelProvider[0] === provider && modelProvider[1] === model
);
// Additional check for the last part of the model name
const modelParts = model.split(/[/.]/);
const lastPart = modelParts[modelParts.length - 1];

return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) => {
const modelNameParts = modelName.split(/[/.]/);
const modelNameLastPart = modelNameParts[modelNameParts.length - 1];
return modelNameLastPart === lastPart;
});
}

export const structureValue = (
Expand Down

0 comments on commit 21a3921

Please sign in to comment.