forked from HamaWhiteGG/langchain-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support GoogleSearchAPIWrapper HamaWhiteGG#110
- Loading branch information
1 parent
a92a484
commit d815b75
Showing
6 changed files
with
279 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
198 changes: 198 additions & 0 deletions
198
langchain-core/src/main/java/com/hw/langchain/utilities/GoogleSearchAPIWrapper.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.utilities; | ||
|
||
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport; | ||
import com.google.api.client.http.HttpRequestInitializer; | ||
import com.google.api.client.http.HttpTransport; | ||
import com.google.api.client.json.gson.GsonFactory; | ||
import com.google.api.services.customsearch.v1.CustomSearchAPI; | ||
import com.google.api.services.customsearch.v1.CustomSearchAPIRequestInitializer; | ||
import com.google.api.services.customsearch.v1.model.Result; | ||
import com.google.common.collect.Maps; | ||
|
||
import org.apache.commons.collections4.CollectionUtils; | ||
|
||
import lombok.SneakyThrows; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import java.util.stream.Collectors; | ||
|
||
import static com.hw.langchain.utils.Utils.getOrEnvOrDefault; | ||
|
||
/** | ||
* Wrapper for Google Search API. | ||
* | ||
* <ol> | ||
* <li>Create an API key</li> | ||
* <li>Setup Custom Search Engine to search the entire web</li> | ||
* <li>Enable the Custom Search API</li> | ||
* </ol> | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public class GoogleSearchAPIWrapper { | ||
|
||
private final CustomSearchAPI customSearch; | ||
|
||
private final String googleCseId; | ||
|
||
/** | ||
* Number of search results to return. | ||
*/ | ||
private final int num; | ||
|
||
@SneakyThrows | ||
private GoogleSearchAPIWrapper(Builder builder) { | ||
HttpTransport transport = GoogleNetHttpTransport.newTrustedTransport(); | ||
HttpRequestInitializer httpRequestInitializer = request -> { | ||
request.setConnectTimeout(builder.connectTimeout); | ||
request.setReadTimeout(builder.readTimeout); | ||
}; | ||
String googleApiKey = getOrEnvOrDefault(builder.googleApiKey, "GOOGLE_API_KEY"); | ||
|
||
this.customSearch = new CustomSearchAPI.Builder(transport, new GsonFactory(), httpRequestInitializer) | ||
.setApplicationName("Google Custom Search") | ||
.setGoogleClientRequestInitializer(new CustomSearchAPIRequestInitializer(googleApiKey)) | ||
.build(); | ||
|
||
this.googleCseId = getOrEnvOrDefault(builder.googleCseId, "GOOGLE_CSE_ID"); | ||
this.num = builder.num; | ||
} | ||
|
||
@SneakyThrows | ||
private List<Result> googleSearchResults(String query, int num) { | ||
return customSearch.cse() | ||
.list() | ||
.setCx(googleCseId) | ||
.setQ(query) | ||
.setNum(num) | ||
.execute() | ||
.getItems(); | ||
} | ||
|
||
/** | ||
* Run a query through Google Search and parse the results. | ||
* | ||
* @param query The search query to be executed. | ||
* @return A string containing snippets from the Google Search results, joined with spaces. | ||
*/ | ||
public String run(String query) { | ||
List<Result> results = googleSearchResults(query, num); | ||
if (CollectionUtils.isEmpty(results)) { | ||
return "No good Google Search Result was found"; | ||
} | ||
return results.stream() | ||
.map(Result::getSnippet) | ||
.filter(Objects::nonNull) | ||
.collect(Collectors.joining(" ")); | ||
} | ||
|
||
/** | ||
* Run query through GoogleSearch and return metadata. | ||
* | ||
* @param query The query to search for. | ||
* @param num The number of results to return. | ||
* @return A list of dictionaries with the following keys: | ||
* <ul> | ||
* <li>snippet - The description of the result.</li> | ||
* <li>title - The title of the result.</li> | ||
* <li>link - The link to the result.</li> | ||
* </ul> | ||
*/ | ||
public List<Map<String, String>> results(String query, int num) { | ||
List<Result> results = googleSearchResults(query, num); | ||
|
||
if (CollectionUtils.isEmpty(results)) { | ||
return List.of(Map.of("Result", "No good Google Search Result was found")); | ||
} | ||
return results.stream().map(result -> { | ||
Map<String, String> metadataResult = Maps.newHashMap(); | ||
metadataResult.put("title", result.getTitle()); | ||
metadataResult.put("link", result.getLink()); | ||
|
||
if (result.getSnippet() != null) { | ||
metadataResult.put("snippet", result.getSnippet()); | ||
} | ||
return metadataResult; | ||
}).toList(); | ||
} | ||
|
||
public static final class Builder { | ||
|
||
/** | ||
* Configure or set the environment variable GOOGLE_API_KEY. | ||
*/ | ||
private String googleApiKey; | ||
|
||
/** | ||
* Configure or set the environment variable GOOGLE_CSE_ID. | ||
*/ | ||
private String googleCseId; | ||
|
||
/** | ||
* Number of search results to return. | ||
*/ | ||
private int num = 10; | ||
|
||
/** | ||
* Timeout in milliseconds to establish a connection or {@code 0} for an infinite timeout. | ||
*/ | ||
private int connectTimeout = 20 * 1000; | ||
|
||
/** | ||
* Timeout in milliseconds to read data from an established connection or {@code 0} for an infinite timeout. | ||
*/ | ||
private int readTimeout = 20 * 1000; | ||
|
||
public Builder() { | ||
} | ||
|
||
public Builder googleApiKey(String googleApiKey) { | ||
this.googleApiKey = googleApiKey; | ||
return this; | ||
} | ||
|
||
public Builder googleCseId(String googleCseId) { | ||
this.googleCseId = googleCseId; | ||
return this; | ||
} | ||
|
||
public Builder num(int num) { | ||
this.num = num; | ||
return this; | ||
} | ||
|
||
public Builder connectTimeout(int connectTimeout) { | ||
this.connectTimeout = connectTimeout; | ||
return this; | ||
} | ||
|
||
public Builder readTimeout(int readTimeout) { | ||
this.readTimeout = readTimeout; | ||
return this; | ||
} | ||
|
||
public GoogleSearchAPIWrapper build() { | ||
return new GoogleSearchAPIWrapper(this); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
54 changes: 54 additions & 0 deletions
54
langchain-core/src/test/java/com/hw/langchain/utilities/GoogleSearchAPIWrapperTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.utilities; | ||
|
||
import org.junit.jupiter.api.Disabled; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
/** | ||
* Set the environment variable GOOGLE_API_KEY and GOOGLE_CSE_ID. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
@Disabled("Test requires Google Search Key , can be run manually.") | ||
class GoogleSearchAPIWrapperTest { | ||
|
||
@Test | ||
void testGoogleSearch() { | ||
String query = "2022 USA national auto sales by brand"; | ||
|
||
GoogleSearchAPIWrapper searchWrapper = new GoogleSearchAPIWrapper.Builder() | ||
.connectTimeout(10 * 1000) | ||
.build(); | ||
|
||
List<Map<String, String>> results = searchWrapper.results(query, 10); | ||
assertThat(results).isNotEmpty().hasSize(10); | ||
|
||
assertThat(results.get(0)).isEqualTo(Map.of( | ||
"title", "Full-Year 2022 National Auto Sales By Brand", | ||
"link", "https://www.carpro.com/blog/full-year-2022-national-auto-sales-by-brand", | ||
"snippet", | ||
"Jan 12, 2023 ... Full-Year 2022 National Auto Sales By Brand ; 1. Toyota, 1,849,751 ; 2. Ford, 1,767,439 ; 3. Chevrolet, 1,502,389 ; 4. Honda, 881,201 ...")); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters