Skip to content

Commit

Permalink
gai: Scripting extensions improvements and cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
patniemeyer committed Dec 13, 2024
1 parent 66f0523 commit 7a3e504
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 80 deletions.
44 changes: 23 additions & 21 deletions gai-frontend/lib/chat/chat.dart
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ class _ChatViewState extends State<ChatView> {
late final ProviderManager _providerManager;

// Models
final ModelManager _modelsState = ModelManager();
List<String> _selectedModelIds = [];
final ModelManager _modelManager = ModelManager();
List<String> _userSelectedModelIds = [];

List<ModelInfo> get _selectedModels =>
_modelsState.getModelsOrDefault(_selectedModelIds);
List<ModelInfo> get _userSelectedModels =>
_modelManager.getModelsOrDefault(_userSelectedModelIds);

// Account
// This should be wrapped up in a provider. See WIP in vpn app.
Expand All @@ -72,7 +72,7 @@ class _ChatViewState extends State<ChatView> {

// Init the provider manager
_providerManager = ProviderManager(
modelsState: _modelsState,
modelsState: _modelManager,
onProviderConnected: providerConnected,
onProviderDisconnected: providerDisconnected,
onChatMessage: _addChatMessage,
Expand All @@ -95,6 +95,8 @@ class _ChatViewState extends State<ChatView> {
url: 'lib/extensions/filter_example.js',
debugMode: true,
providerManager: _providerManager,
modelManager: _modelManager,
getUserSelectedModels: () => _userSelectedModels,
chatHistory: _chatHistory,
addChatMessageToUI: _addChatMessage,
);
Expand Down Expand Up @@ -149,7 +151,7 @@ class _ChatViewState extends State<ChatView> {
} else {
// Disconnects any existing provider connection
_providerManager.setAccountDetail(null);
_modelsState.clear();
_modelManager.clear();
}

_accountDetailNotifier.value =
Expand Down Expand Up @@ -219,19 +221,19 @@ class _ChatViewState extends State<ChatView> {
_chatHistory.addMessage(message);
});
scrollMessagesDown();
log('Chat history updated: ${_chatHistory.messages.length}, ${_chatHistory.messages}');
// log('Chat history updated: ${_chatHistory.messages.length}, ${_chatHistory.messages}');
}

void _updateSelectedModels(List<String> modelIds) {
setState(() {
if (_multiSelectMode) {
_selectedModelIds = modelIds;
_userSelectedModelIds = modelIds;
} else {
// In single-select mode, only keep the most recently selected model
_selectedModelIds = modelIds.isNotEmpty ? [modelIds.last] : [];
_userSelectedModelIds = modelIds.isNotEmpty ? [modelIds.last] : [];
}
});
log('Selected models updated to: $_selectedModelIds');
log('Selected models updated to: $_userSelectedModelIds');
}

void _popAccountDialog() {
Expand Down Expand Up @@ -305,16 +307,16 @@ class _ChatViewState extends State<ChatView> {
}

// Debug hack
if (_selectedModelIds.isEmpty &&
if (_userSelectedModelIds.isEmpty &&
ChatScripting.enabled &&
ChatScripting.instance.debugMode) {
setState(() {
_selectedModelIds = ['gpt-4o'];
_userSelectedModelIds = ['gpt-4o'];
});
}

// Validate the selected models
if (_selectedModelIds.isEmpty) {
if (_userSelectedModelIds.isEmpty) {
_addMessage(
ChatMessageSource.system,
_multiSelectMode
Expand All @@ -329,7 +331,7 @@ class _ChatViewState extends State<ChatView> {

// If we have a script selected allow it to handle the prompt
if (ChatScripting.enabled) {
ChatScripting.instance.sendUserPrompt(msg, _selectedModels);
ChatScripting.instance.sendUserPrompt(msg, _userSelectedModels);
} else {
_sendUserPromptDefaultBehavior(msg);
}
Expand All @@ -348,7 +350,7 @@ class _ChatViewState extends State<ChatView> {
// This strategy selects messages based on the isolated / party mode and sends them sequentially to each
// of the user-selected models allowing each model to see the previous responses.
Future<void> _sendChatHistoryToSelectedModels() async {
for (final modelId in _selectedModelIds) {
for (final modelId in _userSelectedModelIds) {
try {
// Filter messages based on conversation mode.
final selectedMessages = _partyMode
Expand Down Expand Up @@ -380,7 +382,7 @@ class _ChatViewState extends State<ChatView> {
chatResponse.message,
metadata: metadata,
modelId: modelId,
modelName: _modelsState.getModelOrDefaultNullable(modelId)?.name,
modelName: _modelManager.getModelOrDefaultNullable(modelId)?.name,
);
}

Expand Down Expand Up @@ -554,9 +556,9 @@ class _ChatViewState extends State<ChatView> {

// Model selector with loading state
ListenableBuilder(
listenable: _modelsState,
listenable: _modelManager,
builder: (context, _) {
if (_modelsState.isAnyLoading) {
if (_modelManager.isAnyLoading) {
return const SizedBox(
width: buttonHeight,
height: buttonHeight,
Expand All @@ -569,8 +571,8 @@ class _ChatViewState extends State<ChatView> {
}

return ModelSelectionButton(
models: _modelsState.allModels,
selectedModelIds: _selectedModelIds,
models: _modelManager.allModels,
selectedModelIds: _userSelectedModelIds,
updateModels: _updateSelectedModels,
multiSelectMode: _multiSelectMode,
);
Expand Down Expand Up @@ -604,7 +606,7 @@ class _ChatViewState extends State<ChatView> {
setState(() {
_multiSelectMode = !_multiSelectMode;
// Reset selections when toggling modes
_selectedModelIds = [];
_userSelectedModelIds = [];
});
},
onPartyModeChanged: () {
Expand Down
8 changes: 4 additions & 4 deletions gai-frontend/lib/chat/scripting/chat_bindings_js.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ external JSAny? evaluateJS(String jsCode);

// Setter to install the chatHistory variable in JS
// let chatHistory: ReadonlyArray<ChatMessage>;
@JS('chatHistory')
external set chatHistoryJS(JSArray value);
@JS('getChatHistory')
external set getChatHistoryJS(JSFunction value);

// Setter to install the userSelectedModels variable in JS
// List of ModelInfo user-selected models, read-only
@JS('userSelectedModels')
external set userSelectedModelsJS(JSArray value);
@JS('getUserSelectedModels')
external set getUserSelectedModelsJS(JSFunction value);

// Setter to install the sendMessagesToModel callback function in JS
// Send a list of ChatMessage to a model for inference
Expand Down
53 changes: 30 additions & 23 deletions gai-frontend/lib/chat/scripting/chat_scripting.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import 'package:orchid/chat/chat_history.dart';
import 'package:orchid/chat/chat_message.dart';
import 'package:orchid/chat/model.dart';
import 'package:orchid/chat/model_manager.dart';
import 'package:orchid/chat/provider_connection.dart';
import 'package:orchid/chat/provider_manager.dart';
import 'package:orchid/gui-orchid/lib/orchid/orchid.dart';
Expand All @@ -23,9 +24,11 @@ class ChatScripting {

static bool get enabled => _instance != null;

// Scripting config
// Scripting State
late String script;
late ProviderManager providerManager;
late ModelManager modelManager;
late List<ModelInfo> Function() getUserSelectedModels;
late ChatHistory chatHistory;
late void Function(ChatMessage) addChatMessageToUI;
late bool debugMode;
Expand All @@ -38,6 +41,8 @@ class ChatScripting {
bool debugMode = false,
required ProviderManager providerManager,
required ChatHistory chatHistory,
required ModelManager modelManager,
required List<ModelInfo> Function() getUserSelectedModels,
required Function(ChatMessage) addChatMessageToUI,
}) async {
if (_instance != null) {
Expand All @@ -48,14 +53,16 @@ class ChatScripting {
instance.debugMode = debugMode;
instance.providerManager = providerManager;
instance.chatHistory = chatHistory;
instance.modelManager = modelManager;
instance.getUserSelectedModels = getUserSelectedModels;
instance.addChatMessageToUI = addChatMessageToUI;

// Install persistent callback functions
addGlobalBindings();

await instance.loadExtensionScript(url);
// Do one setup and evaluation of the script now
instance.updatePerCallBindings();
instance.evalExtensionScript();
}

Future<void> loadExtensionScript(String url) async {
Expand Down Expand Up @@ -91,46 +98,46 @@ class ChatScripting {

// Install the persistent callback functions
static void addGlobalBindings() {
addChatMessageJS = instance.addChatMessageFromJS.toJS;
sendMessagesToModelJS = instance.sendMessagesToModelFromJS.toJS;
addChatMessageJS = instance.addChatMessageJSImpl.toJS;
sendMessagesToModelJS = instance.sendMessagesToModelJSImpl.toJS;
getChatHistoryJS = instance.getChatHistoryJSImpl.toJS;
getUserSelectedModelsJS = instance.getUserSelectedModelsJSImpl.toJS;
}

// Items that need to be copied before each invocation of the JS scripting extension
void updatePerCallBindings({List<ModelInfo>? userSelectedModels}) {
chatHistoryJS =
ChatMessageJS.fromChatMessages(chatHistory.messages).jsify() as JSArray;
if (userSelectedModels != null) {
userSelectedModelsJS =
ModelInfoJS.fromModelInfos(userSelectedModels).jsify() as JSArray;
}
// Send the user prompt to the JS scripting extension
void sendUserPrompt(String userPrompt, List<ModelInfo> userSelectedModels) {
log("Invoke onUserPrompt on the scripting extension: $userPrompt");

// If debug mode evaluate the script before each usage
if (debugMode) {
evalExtensionScript();
}
}

// Send the user prompt to the JS scripting extension
void sendUserPrompt(String userPrompt, List<ModelInfo> userSelectedModels) {
log("Invoke onUserPrompt on the scripting extension: $userPrompt");
updatePerCallBindings(userSelectedModels: userSelectedModels);
onUserPromptJS(userPrompt);
}

///
/// BEGIN: callbacks from JS
/// BEGIN: JS callback implementations
///
// Implementation of the getChatHistory callback function invoked from JS
JSArray getChatHistoryJSImpl() =>
ChatMessageJS.fromChatMessages(chatHistory.messages).jsify() as JSArray;

// Implementation of the getUserSelectedModels callback function invoked from JS
JSArray getUserSelectedModelsJSImpl() =>
ModelInfoJS.fromModelInfos(getUserSelectedModels()).jsify() as JSArray;

// Implementation of the addChatMessage callback function invoked from JS
// Add a chat message to the local history.
void addChatMessageFromJS(ChatMessageJS message) {
void addChatMessageJSImpl(ChatMessageJS message) {
log("Add chat message: ${message.source}, ${message.msg}");
addChatMessageToUI(ChatMessageJS.toChatMessage(message));
// TODO: This can cause looping, let's invert the relevant calls (e.g. history) so that this is necessary.
// updatePerCallBindings(); // History has changed
}

// Implementation of sendMessagesToModel callback function invoked from JS
// Send a list of ChatMessage to a model for inference and return a promise of ChatMessageJS
JSPromise sendMessagesToModelFromJS(
JSPromise sendMessagesToModelJSImpl(
JSArray messagesJS, String modelId, int? maxTokens) {
log("dart: Send messages to model called.");

Expand Down Expand Up @@ -171,6 +178,6 @@ class ChatScripting {
}

///
/// END: callbacks from JS
/// END: JS callback implementations
///
}
38 changes: 35 additions & 3 deletions gai-frontend/lib/chat/scripting/chat_scripting_api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ class ModelInfo {
}

// List of ChatMessage structs, read-only
declare let chatHistory: ReadonlyArray<ChatMessage>;
declare function getChatHistory(): ReadonlyArray<ChatMessage>;

// List of ModelInfo user-selected models, read-only
declare let userSelectedModels: ReadonlyArray<ModelInfo>;
declare function getUserSelectedModels(): ReadonlyArray<ModelInfo>;

// Send a list of ChatMessage to a model for inference
declare function sendMessagesToModel(
Expand All @@ -58,12 +58,44 @@ declare function sendFormattedMessagesToModel(
// Add a chat message to the history
declare function addChatMessage(chatMessage: ChatMessage): void

// @ts-ignore
// Extension entry point: The user has hit enter on a new prompt.
declare function onUserPrompt(userPrompt: string): void

//
// Helper / util implementations
//

// Add a system message to the chat
function chatSystemMessage(message: string): void {
addChatMessage(new ChatMessage(ChatMessageSource.SYSTEM, message, {}));
}
// Add a provider message to the chat
function chatProviderMessage(message: string): void {
addChatMessage(new ChatMessage(ChatMessageSource.PROVIDER, message, {}));
}
// Add an internal message to the chat
function chatInternalMessage(message: string): void {
addChatMessage(new ChatMessage(ChatMessageSource.INTERNAL, message, {}));
}
// Add a client message to the chat
function chatClientMessage(message: string): void {
addChatMessage(new ChatMessage(ChatMessageSource.CLIENT, message, {}));
}

// Send a list of messages to a model for inference
function chatSendToModel(
messages: Array<ChatMessage>,
modelId: string,
maxTokens: number | null = null,
): Promise<ChatMessage> {
return sendMessagesToModel(messages, modelId, maxTokens);
}

// Get the conversation history for all models
function getConversation(): Array<ChatMessage> {
// Gather messages of source type 'client' or 'provider', irrespective of the provider model
return chatHistory.filter(
return getChatHistory().filter(
(message) =>
message.source === ChatMessageSource.CLIENT ||
message.source === ChatMessageSource.PROVIDER
Expand Down
20 changes: 11 additions & 9 deletions gai-frontend/lib/chat/scripting/extensions/filter_example.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,27 @@
function onUserPrompt(userPrompt: string): void {
(async () => {
// Log a system message to the chat
addChatMessage(new ChatMessage(ChatMessageSource.SYSTEM, 'Extension: Filter Example', {}));
chatSystemMessage('Extension: Filter Example');

// Send the message to 'gpt-4o' and ask it to characterize the user prompt.
let message = new ChatMessage(ChatMessageSource.CLIENT,
let decisionMessage = new ChatMessage(ChatMessageSource.CLIENT,
`The following is a user-generated prompt. Please characterize it as either friendly or ` +
`unfriendly and respond with just your one word decision: ` +
`{BEGIN_PROMPT}${userPrompt}{END_PROMPT}`,
{});
let response = await sendMessagesToModel([message], 'gpt-4o', null);
let response = await chatSendToModel([decisionMessage], 'gpt-4o');
let decision = response.msg.trim().toLowerCase();

// Log the decision to the chat
addChatMessage(new ChatMessage(ChatMessageSource.SYSTEM,
`Extension: User prompt evaluated as: ${decision}`, {}));
chatSystemMessage(`Extension: User prompt evaluated as: ${decision}`);

// Now send the prompt to the first user-selected model
const modelId = userSelectedModels[0].id;
message = new ChatMessage(ChatMessageSource.CLIENT, userPrompt, {});
addChatMessage(await sendMessagesToModel([message], modelId, null));
// Add the user prompt to the chat
let userMessage = new ChatMessage(ChatMessageSource.CLIENT, userPrompt, {});
addChatMessage(userMessage);

// Send it to the first user-selected model
const modelId = getUserSelectedModels()[0].id;
addChatMessage(await sendMessagesToModel([userMessage], modelId, null));

})();
}
Expand Down
Loading

0 comments on commit 7a3e504

Please sign in to comment.