Skip to content

Commit

Permalink
Throw error if more than one retriever/contentRetriever/retrievalAugm… (
Browse files Browse the repository at this point in the history
langchain4j#710)

…entor are set

As we discussed in
quarkiverse/quarkus-langchain4j#353 (comment)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
- Enhanced AI services with new validation rules to ensure exclusive
setting of components, improving system stability and predictability.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
jmartisk authored Mar 8, 2024
1 parent 86f314e commit 2bd7fb1
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
16 changes: 16 additions & 0 deletions langchain4j/src/main/java/dev/langchain4j/service/AiServices.java
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ public abstract class AiServices<T> {

protected final AiServiceContext context;

private boolean retrieverSet = false;
private boolean contentRetrieverSet = false;
private boolean retrievalAugmentorSet = false;

protected AiServices(AiServiceContext context) {
this.context = context;
}
Expand Down Expand Up @@ -311,7 +315,11 @@ public AiServices<T> tools(List<Object> objectsWithTools) {
*/
@Deprecated
public AiServices<T> retriever(Retriever<TextSegment> retriever) {
if(contentRetrieverSet || retrievalAugmentorSet) {
throw illegalConfiguration("Only one out of [retriever, contentRetriever, retrievalAugmentor] can be set");
}
if (retriever != null) {
retrieverSet = true;
return contentRetriever(retriever.toContentRetriever());
}
return this;
Expand All @@ -331,6 +339,10 @@ public AiServices<T> retriever(Retriever<TextSegment> retriever) {
* @return builder
*/
public AiServices<T> contentRetriever(ContentRetriever contentRetriever) {
if(retrieverSet || retrievalAugmentorSet) {
throw illegalConfiguration("Only one out of [retriever, contentRetriever, retrievalAugmentor] can be set");
}
contentRetrieverSet = true;
context.retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.contentRetriever(ensureNotNull(contentRetriever, "contentRetriever"))
.build();
Expand All @@ -344,6 +356,10 @@ public AiServices<T> contentRetriever(ContentRetriever contentRetriever) {
* @return builder
*/
public AiServices<T> retrievalAugmentor(RetrievalAugmentor retrievalAugmentor) {
if(retrieverSet || contentRetrieverSet) {
throw illegalConfiguration("Only one out of [retriever, contentRetriever, retrievalAugmentor] can be set");
}
retrievalAugmentorSet = true;
context.retrievalAugmentor = ensureNotNull(retrievalAugmentor, "retrievalAugmentor");
return this;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package dev.langchain4j.service;

import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.retriever.Retriever;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;

/**
* Verify that the AIServices builder doesn't allow setting more than out of
* (retriever, contentRetriever, retrievalAugmentor).
*/
public class AiServicesBuilderTest {

@Test
public void testRetrieverAndContentRetriever() {
Retriever retriever = mock(Retriever.class);
ContentRetriever contentRetriever = mock(ContentRetriever.class);

assertThrows(IllegalConfigurationException.class, () -> {
AiServices.builder(AiServices.class)
.retriever(retriever)
.contentRetriever(contentRetriever)
.build();
});
}

@Test
public void testRetrieverAndRetrievalAugmentor() {
Retriever retriever = mock(Retriever.class);
RetrievalAugmentor retrievalAugmentor = mock(RetrievalAugmentor.class);

assertThrows(IllegalConfigurationException.class, () -> {
AiServices.builder(AiServices.class)
.retriever(retriever)
.retrievalAugmentor(retrievalAugmentor)
.build();
});
}

@Test
public void testContentRetrieverAndRetrievalAugmentor() {
ContentRetriever contentRetriever = mock(ContentRetriever.class);
RetrievalAugmentor retrievalAugmentor = mock(RetrievalAugmentor.class);

assertThrows(IllegalConfigurationException.class, () -> {
AiServices.builder(AiServices.class)
.contentRetriever(contentRetriever)
.retrievalAugmentor(retrievalAugmentor)
.build();
});
}

@Test
public void testContentRetrieverAndRetriever() {
Retriever retriever = mock(Retriever.class);
ContentRetriever contentRetriever = mock(ContentRetriever.class);

assertThrows(IllegalConfigurationException.class, () -> {
AiServices.builder(AiServices.class)
.contentRetriever(contentRetriever)
.retriever(retriever)
.build();
});
}

@Test
public void testRetrievalAugmentorAndRetriever() {
Retriever retriever = mock(Retriever.class);
RetrievalAugmentor retrievalAugmentor = mock(RetrievalAugmentor.class);

assertThrows(IllegalConfigurationException.class, () -> {
AiServices.builder(AiServices.class)
.retrievalAugmentor(retrievalAugmentor)
.retriever(retriever)
.build();
});
}

@Test
public void testRetrievalAugmentorAndContentRetriever() {
ContentRetriever contentRetriever = mock(ContentRetriever.class);
RetrievalAugmentor retrievalAugmentor = mock(RetrievalAugmentor.class);

assertThrows(IllegalConfigurationException.class, () -> {
AiServices.builder(AiServices.class)
.retrievalAugmentor(retrievalAugmentor)
.contentRetriever(contentRetriever)
.build();
});
}

}

0 comments on commit 2bd7fb1

Please sign in to comment.