forked from li2109/langtorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request li2109#132 from maxiao124/dev_minimax_mark
Dev minimax mark
Showing
43 changed files
with
1,584 additions
and
7 deletions.
There are no files selected for viewing
36 changes: 36 additions & 0 deletions
36
src/main/java/ai/knowly/langtorch/capability/integration/minimax/SimpleChatCapability.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package ai.knowly.langtorch.capability.integration.minimax; | ||
|
||
import ai.knowly.langtorch.capability.modality.text.MiniMaxChatCompletionLLMCapability; | ||
import ai.knowly.langtorch.capability.modality.text.Parsers; | ||
import ai.knowly.langtorch.preprocessing.parser.ChatMessageToStringParser; | ||
import ai.knowly.langtorch.preprocessing.parser.StringToMultiChatMessageParser; | ||
import ai.knowly.langtorch.processor.minimax.chat.MiniMaxChatProcessor; | ||
import ai.knowly.langtorch.schema.chat.ChatMessage; | ||
import ai.knowly.langtorch.schema.text.MultiChatMessage; | ||
import ai.knowly.langtorch.store.memory.conversation.ConversationMemory; | ||
|
||
/** | ||
* A simple chat capability unit that leverages minimax api to generate response | ||
* | ||
* @author maxiao | ||
* @date 2023/06/13 | ||
*/ | ||
public class SimpleChatCapability extends MiniMaxChatCompletionLLMCapability<String, String> { | ||
|
||
public SimpleChatCapability( | ||
MiniMaxChatProcessor miniMaxChatProcessor, ConversationMemory memory) { | ||
super( | ||
miniMaxChatProcessor, | ||
Parsers.<String, MultiChatMessage, ChatMessage, String>builder() | ||
.setInputParser(StringToMultiChatMessageParser.create()) | ||
.setOutputParser(ChatMessageToStringParser.create()) | ||
.build(), | ||
memory); | ||
} | ||
|
||
@Override | ||
public SimpleChatCapability withVerboseMode(boolean verboseMode) { | ||
super.withVerboseMode(verboseMode); | ||
return this; | ||
} | ||
} |
123 changes: 123 additions & 0 deletions
123
...java/ai/knowly/langtorch/capability/modality/text/MiniMaxChatCompletionLLMCapability.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
package ai.knowly.langtorch.capability.modality.text; | ||
|
||
import static com.google.common.util.concurrent.Futures.immediateFuture; | ||
import static com.google.common.util.concurrent.MoreExecutors.directExecutor; | ||
|
||
import ai.knowly.langtorch.processor.Processor; | ||
import ai.knowly.langtorch.schema.chat.ChatMessage; | ||
import ai.knowly.langtorch.schema.text.MultiChatMessage; | ||
import ai.knowly.langtorch.store.memory.Memory; | ||
import ai.knowly.langtorch.store.memory.conversation.ConversationMemoryContext; | ||
import com.google.common.flogger.FluentLogger; | ||
import com.google.common.util.concurrent.FluentFuture; | ||
import com.google.common.util.concurrent.ListenableFuture; | ||
import com.google.inject.Inject; | ||
|
||
/** | ||
* @author maxiao | ||
* @date 2023/06/11 | ||
*/ | ||
public class MiniMaxChatCompletionLLMCapability<I, O> | ||
implements TextLLMCapabilityWithMemory< | ||
I, MultiChatMessage, ChatMessage, O, ChatMessage, ConversationMemoryContext> { | ||
|
||
private static final FluentLogger logger = FluentLogger.forEnclosingClass(); | ||
|
||
private final Processor<MultiChatMessage, ChatMessage> processor; | ||
private Parsers<I, MultiChatMessage, ChatMessage, O> parsers; | ||
private Memory<ChatMessage, ConversationMemoryContext> memory; | ||
private boolean verbose; | ||
|
||
@Inject | ||
public MiniMaxChatCompletionLLMCapability( | ||
Processor<MultiChatMessage, ChatMessage> processor, | ||
Parsers<I, MultiChatMessage, ChatMessage, O> parsers, | ||
Memory<ChatMessage, ConversationMemoryContext> memory) { | ||
this.processor = processor; | ||
this.parsers = parsers; | ||
this.memory = memory; | ||
this.verbose = false; | ||
} | ||
|
||
protected MiniMaxChatCompletionLLMCapability<I, O> withVerboseMode(boolean verbose) { | ||
this.verbose = verbose; | ||
return this; | ||
} | ||
|
||
@Override | ||
public O run(I inputData) { | ||
return postProcess(generateMemorySideEffectResponse(preProcess(inputData))); | ||
} | ||
|
||
private ChatMessage generateMemorySideEffectResponse(MultiChatMessage multiChatMessage) { | ||
if (verbose) { | ||
logger.atInfo().log("Memory before processing: %s", memory); | ||
} | ||
ChatMessage response = processor.run(getMessageWithMemorySideEffect(multiChatMessage)); | ||
// Adding prompt and response. | ||
multiChatMessage.getMessages().forEach(memory::add); | ||
ChatMessage message = ChatMessage.of(response.getContent(), response.getRole()); | ||
memory.add(message); | ||
return message; | ||
} | ||
|
||
private MultiChatMessage getMessageWithMemorySideEffect(MultiChatMessage message) { | ||
// Memory context being empty means that this is the first message in the conversation | ||
String memoryContext = memory.getMemoryContext().get(); | ||
if (memoryContext.isEmpty()) { | ||
return message; | ||
} | ||
|
||
MultiChatMessage updatedMessage = | ||
message.getMessages().stream() | ||
.map( | ||
chatMessage -> | ||
ChatMessage.of( | ||
String.format( | ||
"%s%nBelow is my query:%n%s", memoryContext, chatMessage.toString()), | ||
chatMessage.getRole())) | ||
.collect(MultiChatMessage.toMultiChatMessage()); | ||
|
||
if (verbose) { | ||
logger.atInfo().log("Updated Message with Memory Side Effect: %s", updatedMessage); | ||
} | ||
|
||
return updatedMessage; | ||
} | ||
|
||
@Override | ||
public MultiChatMessage preProcess(I inputData) { | ||
if (inputData instanceof MultiChatMessage) { | ||
return (MultiChatMessage) inputData; | ||
} | ||
|
||
return parsers | ||
.getInputParser() | ||
.map(parser -> parser.parse(inputData)) | ||
.orElseThrow( | ||
() -> | ||
new IllegalArgumentException( | ||
"Input data is not a MultiChatMessage and no input parser is present.")); | ||
} | ||
|
||
@Override | ||
public Memory<ChatMessage, ConversationMemoryContext> getMemory() { | ||
return memory; | ||
} | ||
|
||
@Override | ||
public O postProcess(ChatMessage outputData) { | ||
return parsers | ||
.getOutputParser() | ||
.map(parser -> parser.parse(outputData)) | ||
.orElseThrow( | ||
() -> | ||
new IllegalArgumentException( | ||
"Output data type is not ChatMessage and no output parser is present.")); | ||
} | ||
|
||
@Override | ||
public ListenableFuture<O> runAsync(I inputData) { | ||
return FluentFuture.from(immediateFuture(inputData)).transform(this::run, directExecutor()); | ||
} | ||
} |
24 changes: 24 additions & 0 deletions
24
src/main/java/ai/knowly/langtorch/llm/minimax/MiniMaxApi.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
package ai.knowly.langtorch.llm.minimax; | ||
|
||
import ai.knowly.langtorch.llm.minimax.schema.dto.completion.ChatCompletionRequest; | ||
import ai.knowly.langtorch.llm.minimax.schema.dto.completion.ChatCompletionResult; | ||
import ai.knowly.langtorch.llm.minimax.schema.dto.embedding.EmbeddingRequest; | ||
import ai.knowly.langtorch.llm.minimax.schema.dto.embedding.EmbeddingResult; | ||
import com.google.common.util.concurrent.ListenableFuture; | ||
import retrofit2.http.Body; | ||
import retrofit2.http.POST; | ||
|
||
/** | ||
* doc link: https://api.minimax.chat/document/guides | ||
* | ||
* @author maxiao | ||
* @date 2023/06/07 | ||
*/ | ||
public interface MiniMaxApi { | ||
|
||
@POST("/v1/text/chatcompletion") | ||
ListenableFuture<ChatCompletionResult> createChatCompletion(@Body ChatCompletionRequest request); | ||
|
||
@POST("/v1/embeddings") | ||
ListenableFuture<EmbeddingResult> createEmbeddings(@Body EmbeddingRequest request); | ||
} |
46 changes: 46 additions & 0 deletions
46
src/main/java/ai/knowly/langtorch/llm/minimax/MiniMaxAuthenticationInterceptor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package ai.knowly.langtorch.llm.minimax; | ||
|
||
import okhttp3.HttpUrl; | ||
import okhttp3.Interceptor; | ||
import okhttp3.Request; | ||
import okhttp3.Response; | ||
import org.jetbrains.annotations.NotNull; | ||
|
||
import java.io.IOException; | ||
import java.util.Objects; | ||
|
||
/** | ||
* OkHttp Interceptor that adds an authorization token header | ||
* | ||
* @author maxiao | ||
* @date 2023/06/07 | ||
*/ | ||
public class MiniMaxAuthenticationInterceptor implements Interceptor { | ||
|
||
private final String groupId; | ||
private final String apiKey; | ||
|
||
MiniMaxAuthenticationInterceptor(String groupId, String apiKey) { | ||
Objects.requireNonNull(groupId, "Minimax groupId required"); | ||
Objects.requireNonNull(apiKey, "Minimax apiKey required"); | ||
this.groupId = groupId; | ||
this.apiKey = apiKey; | ||
} | ||
|
||
@Override | ||
public Response intercept(@NotNull Chain chain) throws IOException { | ||
|
||
HttpUrl url = chain.request().url(); | ||
HttpUrl completeUrl = url.newBuilder().addQueryParameter("GroupId", groupId).build(); | ||
|
||
Request request = | ||
chain | ||
.request() | ||
.newBuilder() | ||
.url(completeUrl) | ||
.header("Authorization", "Bearer " + apiKey) | ||
.header("Content-Type", "application/json") | ||
.build(); | ||
return chain.proceed(request); | ||
} | ||
} |
151 changes: 151 additions & 0 deletions
151
src/main/java/ai/knowly/langtorch/llm/minimax/MiniMaxService.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
package ai.knowly.langtorch.llm.minimax; | ||
|
||
import ai.knowly.langtorch.llm.minimax.schema.MiniMaxApiBusinessErrorException; | ||
import ai.knowly.langtorch.llm.minimax.schema.MiniMaxApiExecutionException; | ||
import ai.knowly.langtorch.llm.minimax.schema.MiniMaxApiServiceInterruptedException; | ||
import ai.knowly.langtorch.llm.minimax.schema.config.MiniMaxServiceConfig; | ||
import ai.knowly.langtorch.llm.minimax.schema.dto.BaseResp; | ||
import ai.knowly.langtorch.llm.minimax.schema.dto.completion.ChatCompletionRequest; | ||
import ai.knowly.langtorch.llm.minimax.schema.dto.completion.ChatCompletionResult; | ||
import ai.knowly.langtorch.llm.minimax.schema.dto.embedding.EmbeddingRequest; | ||
import ai.knowly.langtorch.llm.minimax.schema.dto.embedding.EmbeddingResult; | ||
import ai.knowly.langtorch.utils.future.retry.FutureRetrier; | ||
import com.fasterxml.jackson.annotation.JsonInclude; | ||
import com.fasterxml.jackson.databind.DeserializationFeature; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import com.fasterxml.jackson.databind.PropertyNamingStrategy; | ||
import com.google.common.flogger.FluentLogger; | ||
import com.google.common.util.concurrent.ListenableFuture; | ||
import com.google.inject.Inject; | ||
import java.io.IOException; | ||
import java.util.concurrent.ExecutionException; | ||
import java.util.concurrent.Executors; | ||
import java.util.concurrent.ScheduledExecutorService; | ||
import java.util.concurrent.TimeUnit; | ||
import okhttp3.ConnectionPool; | ||
import okhttp3.OkHttpClient; | ||
import retrofit2.HttpException; | ||
import retrofit2.Retrofit; | ||
import retrofit2.adapter.guava.GuavaCallAdapterFactory; | ||
import retrofit2.converter.jackson.JacksonConverterFactory; | ||
|
||
/** | ||
* MiniMaxService wraps MiniMaxApi and provides a synchronous and asynchronous interface to the | ||
* MiniMax API | ||
* | ||
* @author maxiao | ||
* @date 2023/06/07 | ||
*/ | ||
public class MiniMaxService { | ||
private static final FluentLogger logger = FluentLogger.forEnclosingClass(); | ||
private static final String BASE_URL = "https://api.minimax.chat"; | ||
|
||
private final MiniMaxApi api; | ||
private final FutureRetrier futureRetrier; | ||
|
||
private final ScheduledExecutorService scheduledExecutor; | ||
|
||
@Inject | ||
public MiniMaxService(final MiniMaxServiceConfig miniMaxServiceConfig) { | ||
ObjectMapper defaultObjectMapper = defaultObjectMapper(); | ||
OkHttpClient client = buildClient(miniMaxServiceConfig); | ||
Retrofit retrofit = defaultRetrofit(client, defaultObjectMapper); | ||
scheduledExecutor = Executors.newSingleThreadScheduledExecutor(); | ||
this.futureRetrier = | ||
new FutureRetrier( | ||
scheduledExecutor, | ||
miniMaxServiceConfig.backoffStrategy(), | ||
miniMaxServiceConfig.retryConfig()); | ||
this.api = retrofit.create(MiniMaxApi.class); | ||
} | ||
|
||
public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper) { | ||
return new Retrofit.Builder() | ||
.baseUrl(BASE_URL) | ||
.client(client) | ||
.addConverterFactory(JacksonConverterFactory.create(mapper)) | ||
.addCallAdapterFactory(GuavaCallAdapterFactory.create()) | ||
.build(); | ||
} | ||
|
||
public static OkHttpClient buildClient(MiniMaxServiceConfig miniMaxServiceConfig) { | ||
OkHttpClient.Builder builder = | ||
new OkHttpClient.Builder() | ||
.addInterceptor( | ||
new MiniMaxAuthenticationInterceptor( | ||
miniMaxServiceConfig.groupId(), miniMaxServiceConfig.apiKey())) | ||
.connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS)) | ||
.readTimeout(miniMaxServiceConfig.timeoutDuration().toMillis(), TimeUnit.MILLISECONDS); | ||
|
||
return builder.build(); | ||
} | ||
|
||
public static ObjectMapper defaultObjectMapper() { | ||
ObjectMapper mapper = new ObjectMapper(); | ||
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); | ||
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); | ||
mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); | ||
return mapper; | ||
} | ||
|
||
public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) { | ||
ChatCompletionResult chatCompletionResult = | ||
execute( | ||
futureRetrier.runWithRetries(() -> api.createChatCompletion(request), result -> true)); | ||
|
||
checkResp(chatCompletionResult.getBaseResp()); | ||
return chatCompletionResult; | ||
} | ||
|
||
public ListenableFuture<ChatCompletionResult> createChatCompletionAsync( | ||
ChatCompletionRequest request) { | ||
return futureRetrier.runWithRetries(() -> api.createChatCompletion(request), result -> true); | ||
} | ||
|
||
public EmbeddingResult createEmbeddings(EmbeddingRequest request) { | ||
EmbeddingResult embeddingResult = | ||
execute(futureRetrier.runWithRetries(() -> api.createEmbeddings(request), result -> true)); | ||
|
||
checkResp(embeddingResult.getBaseResp()); | ||
return embeddingResult; | ||
} | ||
|
||
public ListenableFuture<EmbeddingResult> createEmbeddingsAsync(EmbeddingRequest request) { | ||
return futureRetrier.runWithRetries(() -> api.createEmbeddings(request), result -> true); | ||
} | ||
|
||
/** Throw exception messages if the request fails */ | ||
public void checkResp(BaseResp baseResp) { | ||
if (baseResp.getStatusCode() != 0) { | ||
throw new MiniMaxApiBusinessErrorException(baseResp.getStatusCode(), baseResp.getStatusMsg()); | ||
} | ||
} | ||
|
||
/** | ||
* Calls the MiniMax AI api, returns the response, and parses error messages if the request fails | ||
*/ | ||
public static <T> T execute(ListenableFuture<T> apiCall) { | ||
try { | ||
return apiCall.get(); | ||
} catch (InterruptedException e) { | ||
// Restore the interrupt status | ||
Thread.currentThread().interrupt(); | ||
// Optionally, log or handle the exception here. | ||
logger.atSevere().withCause(e).log("Thread was interrupted during API call."); | ||
throw new MiniMaxApiServiceInterruptedException(e); | ||
|
||
} catch (ExecutionException e) { | ||
if (e.getCause() instanceof HttpException) { | ||
HttpException httpException = (HttpException) e.getCause(); | ||
try { | ||
String errorBody = httpException.response().errorBody().string(); | ||
logger.atSevere().log("HTTP Error: %s", errorBody); | ||
} catch (IOException ioException) { | ||
logger.atSevere().withCause(ioException).log("Error while reading errorBody"); | ||
} | ||
} | ||
|
||
throw new MiniMaxApiExecutionException(e); | ||
} | ||
} | ||
} |
14 changes: 14 additions & 0 deletions
14
src/main/java/ai/knowly/langtorch/llm/minimax/schema/MiniMaxApiBusinessErrorException.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package ai.knowly.langtorch.llm.minimax.schema; | ||
/**
 * Thrown when the MiniMax API answers with a business-level error, i.e. a response whose status
 * code signals failure. The exception message is the status message supplied by the caller.
 *
 * @author maxiao
 * @date 2023/06/17
 */
public class MiniMaxApiBusinessErrorException extends RuntimeException {

  // Kept package-private so existing same-package readers keep working; use the getter elsewhere.
  final Long statusCode;

  /**
   * @param statusCode status code reported by the API; stored as given (may be null)
   * @param statusMessage status message reported by the API; used as the exception message
   */
  public MiniMaxApiBusinessErrorException(Long statusCode, String statusMessage) {
    super(statusMessage);
    this.statusCode = statusCode;
  }

  /**
   * Returns the status code this exception was created with.
   *
   * @return the status code, or null if none was supplied
   */
  public Long getStatusCode() {
    return statusCode;
  }
}
Oops, something went wrong.