Skip to content

Commit

Permalink
Merge pull request li2109#132 from maxiao124/dev_minimax_mark
Browse files Browse the repository at this point in the history
Dev minimax mark
li2109 authored Jun 18, 2023
2 parents d816cfc + c638fdd commit c486fa9
Showing 43 changed files with 1,584 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package ai.knowly.langtorch.capability.integration.minimax;

import ai.knowly.langtorch.capability.modality.text.MiniMaxChatCompletionLLMCapability;
import ai.knowly.langtorch.capability.modality.text.Parsers;
import ai.knowly.langtorch.preprocessing.parser.ChatMessageToStringParser;
import ai.knowly.langtorch.preprocessing.parser.StringToMultiChatMessageParser;
import ai.knowly.langtorch.processor.minimax.chat.MiniMaxChatProcessor;
import ai.knowly.langtorch.schema.chat.ChatMessage;
import ai.knowly.langtorch.schema.text.MultiChatMessage;
import ai.knowly.langtorch.store.memory.conversation.ConversationMemory;

/**
 * A simple chat capability unit that leverages minimax api to generate response.
 *
 * <p>Input strings are parsed into a {@link MultiChatMessage}, run through the MiniMax chat
 * processor, and the resulting {@link ChatMessage} is parsed back into a plain string.
 *
 * @author maxiao
 * @date 2023/06/13
 */
public class SimpleChatCapability extends MiniMaxChatCompletionLLMCapability<String, String> {

  public SimpleChatCapability(
      MiniMaxChatProcessor miniMaxChatProcessor, ConversationMemory memory) {
    super(miniMaxChatProcessor, createStringParsers(), memory);
  }

  /** Builds the String-to-chat-message parser pair used by this capability. */
  private static Parsers<String, MultiChatMessage, ChatMessage, String> createStringParsers() {
    return Parsers.<String, MultiChatMessage, ChatMessage, String>builder()
        .setInputParser(StringToMultiChatMessageParser.create())
        .setOutputParser(ChatMessageToStringParser.create())
        .build();
  }

  @Override
  public SimpleChatCapability withVerboseMode(boolean verboseMode) {
    super.withVerboseMode(verboseMode);
    return this;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package ai.knowly.langtorch.capability.modality.text;

import static com.google.common.util.concurrent.Futures.immediateFuture;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;

import ai.knowly.langtorch.processor.Processor;
import ai.knowly.langtorch.schema.chat.ChatMessage;
import ai.knowly.langtorch.schema.text.MultiChatMessage;
import ai.knowly.langtorch.store.memory.Memory;
import ai.knowly.langtorch.store.memory.conversation.ConversationMemoryContext;
import com.google.common.flogger.FluentLogger;
import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.inject.Inject;

/**
 * Text-in/text-out LLM capability backed by the MiniMax chat completion API.
 *
 * <p>Pipeline: the input is converted into a {@link MultiChatMessage} (directly, or via the
 * configured input parser), conversational memory context is prepended to the outgoing prompt, the
 * processor is invoked, both prompt and response are recorded into memory, and the resulting
 * {@link ChatMessage} is converted into the output type via the output parser.
 *
 * @param <I> input type accepted by {@link #run}
 * @param <O> output type produced by {@link #run}
 * @author maxiao
 * @date 2023/06/11
 */
public class MiniMaxChatCompletionLLMCapability<I, O>
    implements TextLLMCapabilityWithMemory<
        I, MultiChatMessage, ChatMessage, O, ChatMessage, ConversationMemoryContext> {

  private static final FluentLogger logger = FluentLogger.forEnclosingClass();

  private final Processor<MultiChatMessage, ChatMessage> processor;
  // Never reassigned after construction; declared final to make immutability explicit.
  private final Parsers<I, MultiChatMessage, ChatMessage, O> parsers;
  private final Memory<ChatMessage, ConversationMemoryContext> memory;
  // Mutable by design: toggled through withVerboseMode.
  private boolean verbose;

  @Inject
  public MiniMaxChatCompletionLLMCapability(
      Processor<MultiChatMessage, ChatMessage> processor,
      Parsers<I, MultiChatMessage, ChatMessage, O> parsers,
      Memory<ChatMessage, ConversationMemoryContext> memory) {
    this.processor = processor;
    this.parsers = parsers;
    this.memory = memory;
    this.verbose = false;
  }

  /** Enables or disables verbose logging; returns {@code this} for chaining. */
  protected MiniMaxChatCompletionLLMCapability<I, O> withVerboseMode(boolean verbose) {
    this.verbose = verbose;
    return this;
  }

  /** Runs the full pre-process / generate / post-process pipeline for one input. */
  @Override
  public O run(I inputData) {
    return postProcess(generateMemorySideEffectResponse(preProcess(inputData)));
  }

  /**
   * Invokes the processor with memory context attached, then records the prompt messages and the
   * response into memory.
   */
  private ChatMessage generateMemorySideEffectResponse(MultiChatMessage multiChatMessage) {
    if (verbose) {
      logger.atInfo().log("Memory before processing: %s", memory);
    }
    ChatMessage response = processor.run(getMessageWithMemorySideEffect(multiChatMessage));
    // Adding prompt and response.
    multiChatMessage.getMessages().forEach(memory::add);
    ChatMessage message = ChatMessage.of(response.getContent(), response.getRole());
    memory.add(message);
    return message;
  }

  /**
   * Returns the message with accumulated memory context prepended to each chat message, or the
   * original message when memory is empty.
   */
  private MultiChatMessage getMessageWithMemorySideEffect(MultiChatMessage message) {
    // Memory context being empty means that this is the first message in the conversation
    String memoryContext = memory.getMemoryContext().get();
    if (memoryContext.isEmpty()) {
      return message;
    }

    MultiChatMessage updatedMessage =
        message.getMessages().stream()
            .map(
                chatMessage ->
                    ChatMessage.of(
                        String.format(
                            // NOTE(review): uses chatMessage.toString(), which may include role
                            // or other metadata rather than just the content — confirm intended.
                            "%s%nBelow is my query:%n%s", memoryContext, chatMessage.toString()),
                        chatMessage.getRole()))
            .collect(MultiChatMessage.toMultiChatMessage());

    if (verbose) {
      logger.atInfo().log("Updated Message with Memory Side Effect: %s", updatedMessage);
    }

    return updatedMessage;
  }

  /**
   * Converts the raw input into a {@link MultiChatMessage}, either by direct cast or via the
   * configured input parser.
   *
   * @throws IllegalArgumentException if the input is not a MultiChatMessage and no input parser is
   *     configured
   */
  @Override
  public MultiChatMessage preProcess(I inputData) {
    if (inputData instanceof MultiChatMessage) {
      return (MultiChatMessage) inputData;
    }

    return parsers
        .getInputParser()
        .map(parser -> parser.parse(inputData))
        .orElseThrow(
            () ->
                new IllegalArgumentException(
                    "Input data is not a MultiChatMessage and no input parser is present."));
  }

  @Override
  public Memory<ChatMessage, ConversationMemoryContext> getMemory() {
    return memory;
  }

  /**
   * Converts the model's {@link ChatMessage} response into the output type via the configured
   * output parser.
   *
   * @throws IllegalArgumentException if no output parser is configured
   */
  @Override
  public O postProcess(ChatMessage outputData) {
    return parsers
        .getOutputParser()
        .map(parser -> parser.parse(outputData))
        .orElseThrow(
            () ->
                new IllegalArgumentException(
                    "Output data type is not ChatMessage and no output parser is present."));
  }

  /**
   * Async wrapper around {@link #run}. Note: directExecutor means the work still executes on the
   * caller's thread, not a background one.
   */
  @Override
  public ListenableFuture<O> runAsync(I inputData) {
    return FluentFuture.from(immediateFuture(inputData)).transform(this::run, directExecutor());
  }
}
24 changes: 24 additions & 0 deletions src/main/java/ai/knowly/langtorch/llm/minimax/MiniMaxApi.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package ai.knowly.langtorch.llm.minimax;

import ai.knowly.langtorch.llm.minimax.schema.dto.completion.ChatCompletionRequest;
import ai.knowly.langtorch.llm.minimax.schema.dto.completion.ChatCompletionResult;
import ai.knowly.langtorch.llm.minimax.schema.dto.embedding.EmbeddingRequest;
import ai.knowly.langtorch.llm.minimax.schema.dto.embedding.EmbeddingResult;
import com.google.common.util.concurrent.ListenableFuture;
import retrofit2.http.Body;
import retrofit2.http.POST;

/**
 * Retrofit definition of the MiniMax REST endpoints used by this library.
 *
 * <p>doc link: https://api.minimax.chat/document/guides
 *
 * @author maxiao
 * @date 2023/06/07
 */
public interface MiniMaxApi {

  /** Requests a chat completion; the future completes with the parsed API response. */
  @POST("/v1/text/chatcompletion")
  ListenableFuture<ChatCompletionResult> createChatCompletion(@Body ChatCompletionRequest request);

  /** Requests embeddings for the given input; the future completes with the parsed API response. */
  @POST("/v1/embeddings")
  ListenableFuture<EmbeddingResult> createEmbeddings(@Body EmbeddingRequest request);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package ai.knowly.langtorch.llm.minimax;

import okhttp3.HttpUrl;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;
import java.util.Objects;

/**
 * OkHttp Interceptor that adds an authorization token header.
 *
 * <p>Every outgoing request gets the mandatory {@code GroupId} query parameter plus an
 * {@code Authorization} bearer token and a JSON content-type header.
 *
 * @author maxiao
 * @date 2023/06/07
 */
public class MiniMaxAuthenticationInterceptor implements Interceptor {

  private final String groupId;
  private final String apiKey;

  MiniMaxAuthenticationInterceptor(String groupId, String apiKey) {
    Objects.requireNonNull(groupId, "Minimax groupId required");
    Objects.requireNonNull(apiKey, "Minimax apiKey required");
    this.groupId = groupId;
    this.apiKey = apiKey;
  }

  @Override
  public Response intercept(@NotNull Chain chain) throws IOException {
    Request original = chain.request();

    // MiniMax requires the group id as a query parameter on every call.
    HttpUrl urlWithGroupId =
        original.url().newBuilder().addQueryParameter("GroupId", groupId).build();

    Request authenticated =
        original
            .newBuilder()
            .url(urlWithGroupId)
            .header("Authorization", "Bearer " + apiKey)
            .header("Content-Type", "application/json")
            .build();
    return chain.proceed(authenticated);
  }
}
151 changes: 151 additions & 0 deletions src/main/java/ai/knowly/langtorch/llm/minimax/MiniMaxService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package ai.knowly.langtorch.llm.minimax;

import ai.knowly.langtorch.llm.minimax.schema.MiniMaxApiBusinessErrorException;
import ai.knowly.langtorch.llm.minimax.schema.MiniMaxApiExecutionException;
import ai.knowly.langtorch.llm.minimax.schema.MiniMaxApiServiceInterruptedException;
import ai.knowly.langtorch.llm.minimax.schema.config.MiniMaxServiceConfig;
import ai.knowly.langtorch.llm.minimax.schema.dto.BaseResp;
import ai.knowly.langtorch.llm.minimax.schema.dto.completion.ChatCompletionRequest;
import ai.knowly.langtorch.llm.minimax.schema.dto.completion.ChatCompletionResult;
import ai.knowly.langtorch.llm.minimax.schema.dto.embedding.EmbeddingRequest;
import ai.knowly.langtorch.llm.minimax.schema.dto.embedding.EmbeddingResult;
import ai.knowly.langtorch.utils.future.retry.FutureRetrier;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.google.common.flogger.FluentLogger;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.inject.Inject;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;
import retrofit2.HttpException;
import retrofit2.Retrofit;
import retrofit2.adapter.guava.GuavaCallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;

/**
* MiniMaxService wraps MiniMaxApi and provides a synchronous and asynchronous interface to the
* MiniMax API
*
* @author maxiao
* @date 2023/06/07
*/
public class MiniMaxService {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
private static final String BASE_URL = "https://api.minimax.chat";

private final MiniMaxApi api;
private final FutureRetrier futureRetrier;

private final ScheduledExecutorService scheduledExecutor;

@Inject
public MiniMaxService(final MiniMaxServiceConfig miniMaxServiceConfig) {
ObjectMapper defaultObjectMapper = defaultObjectMapper();
OkHttpClient client = buildClient(miniMaxServiceConfig);
Retrofit retrofit = defaultRetrofit(client, defaultObjectMapper);
scheduledExecutor = Executors.newSingleThreadScheduledExecutor();
this.futureRetrier =
new FutureRetrier(
scheduledExecutor,
miniMaxServiceConfig.backoffStrategy(),
miniMaxServiceConfig.retryConfig());
this.api = retrofit.create(MiniMaxApi.class);
}

public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper) {
return new Retrofit.Builder()
.baseUrl(BASE_URL)
.client(client)
.addConverterFactory(JacksonConverterFactory.create(mapper))
.addCallAdapterFactory(GuavaCallAdapterFactory.create())
.build();
}

public static OkHttpClient buildClient(MiniMaxServiceConfig miniMaxServiceConfig) {
OkHttpClient.Builder builder =
new OkHttpClient.Builder()
.addInterceptor(
new MiniMaxAuthenticationInterceptor(
miniMaxServiceConfig.groupId(), miniMaxServiceConfig.apiKey()))
.connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS))
.readTimeout(miniMaxServiceConfig.timeoutDuration().toMillis(), TimeUnit.MILLISECONDS);

return builder.build();
}

public static ObjectMapper defaultObjectMapper() {
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE);
return mapper;
}

public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
ChatCompletionResult chatCompletionResult =
execute(
futureRetrier.runWithRetries(() -> api.createChatCompletion(request), result -> true));

checkResp(chatCompletionResult.getBaseResp());
return chatCompletionResult;
}

public ListenableFuture<ChatCompletionResult> createChatCompletionAsync(
ChatCompletionRequest request) {
return futureRetrier.runWithRetries(() -> api.createChatCompletion(request), result -> true);
}

public EmbeddingResult createEmbeddings(EmbeddingRequest request) {
EmbeddingResult embeddingResult =
execute(futureRetrier.runWithRetries(() -> api.createEmbeddings(request), result -> true));

checkResp(embeddingResult.getBaseResp());
return embeddingResult;
}

public ListenableFuture<EmbeddingResult> createEmbeddingsAsync(EmbeddingRequest request) {
return futureRetrier.runWithRetries(() -> api.createEmbeddings(request), result -> true);
}

/** Throw exception messages if the request fails */
public void checkResp(BaseResp baseResp) {
if (baseResp.getStatusCode() != 0) {
throw new MiniMaxApiBusinessErrorException(baseResp.getStatusCode(), baseResp.getStatusMsg());
}
}

/**
* Calls the MiniMax AI api, returns the response, and parses error messages if the request fails
*/
public static <T> T execute(ListenableFuture<T> apiCall) {
try {
return apiCall.get();
} catch (InterruptedException e) {
// Restore the interrupt status
Thread.currentThread().interrupt();
// Optionally, log or handle the exception here.
logger.atSevere().withCause(e).log("Thread was interrupted during API call.");
throw new MiniMaxApiServiceInterruptedException(e);

} catch (ExecutionException e) {
if (e.getCause() instanceof HttpException) {
HttpException httpException = (HttpException) e.getCause();
try {
String errorBody = httpException.response().errorBody().string();
logger.atSevere().log("HTTP Error: %s", errorBody);
} catch (IOException ioException) {
logger.atSevere().withCause(ioException).log("Error while reading errorBody");
}
}

throw new MiniMaxApiExecutionException(e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package ai.knowly.langtorch.llm.minimax.schema;
/**
 * Thrown when a MiniMax API call succeeds at the HTTP level but the response carries a non-zero
 * business status code in {@code base_resp}.
 *
 * @author maxiao
 * @date 2023/06/17
 */
public class MiniMaxApiBusinessErrorException extends RuntimeException {

  // MiniMax business status code from the response (non-zero indicates an error).
  final Long statusCode;

  public MiniMaxApiBusinessErrorException(Long statusCode, String statusMessage) {
    super(statusMessage);
    this.statusCode = statusCode;
  }

  /** Returns the MiniMax business status code that triggered this exception. */
  public Long getStatusCode() {
    return statusCode;
  }
}
Loading
Oops, something went wrong.

0 comments on commit c486fa9

Please sign in to comment.