From 4f490a0aef23ce64c2c60791b1ed7164ea2a3dc4 Mon Sep 17 00:00:00 2001 From: HamaWhite Date: Tue, 18 Jul 2023 00:13:43 +0800 Subject: [PATCH] Support API chains #26 --- README.md | 9 +- .../langchain/chains/api/base/ApiChain.java | 133 ++++++++++++++++++ .../chains/api/meteo/OpenMeteoDocs.java | 63 +++++++++ .../langchain/chains/api/prompt/Prompt.java | 62 ++++++++ .../com/hw/langchain/requests/Requests.java | 93 ++++++++++++ .../requests/TextRequestsWrapper.java | 85 +++++++++++ .../chains/api/base/ApiChainTest.java | 54 +++++++ .../examples/chains/ApiChainExample.java | 46 ++++++ 8 files changed, 540 insertions(+), 5 deletions(-) create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/api/base/ApiChain.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/api/meteo/OpenMeteoDocs.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/api/prompt/Prompt.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/requests/Requests.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/requests/TextRequestsWrapper.java create mode 100644 langchain-core/src/test/java/com/hw/langchain/chains/api/base/ApiChainTest.java create mode 100644 langchain-examples/src/main/java/com/hw/langchain/examples/chains/ApiChainExample.java diff --git a/README.md b/README.md index 0044f9909..59bd5489a 100644 --- a/README.md +++ b/README.md @@ -8,18 +8,17 @@ This is the Java language implementation of LangChain. Large language models (LLMs) are emerging as a transformative technology, enabling developers to build applications that they previously could not. But using these LLMs in isolation is often not enough to create a truly powerful app - the real power comes when you can combine them with other sources of computation or knowledge. -This library is aimed at assisting in the development of those types of applications. +This library is aimed at assisting in the development of those types of applications. Looking for the Python version? Check out [LangChain](https://github.com/hwchase17/langchain). -Looking for the Python version? Check out [LangChain](https://github.com/hwchase17/langchain). - -The following example can view in the [langchain-example](langchain-examples/src/main/java/com/hw/langchain/examples) +The following example can view in the [langchain-example](langchain-examples/src/main/java/com/hw/langchain/examples) ## 2. User cases - [SQL Chains](langchain-examples/src/main/java/com/hw/langchain/examples/chains/SqlChainExample.java) +- [API Chains](langchain-examples/src/main/java/com/hw/langchain/examples/chains/ApiChainExample.java) - [Agent with Google Search](langchain-examples/src/main/java/com/hw/langchain/examples/agents/LlmAgentExample.java) - [Question answering over documents](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalQaExample.java) -- [Context aware text splitting and QA / Chat](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalMarkdownExample.java) +- [Context aware text splitting and QA](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalMarkdownExample.java) ## 3. Quickstart Guide This tutorial gives you a quick walkthrough about building an end-to-end language model application with LangChain. diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/api/base/ApiChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/api/base/ApiChain.java new file mode 100644 index 000000000..397a77357 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/api/base/ApiChain.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.api.base; + +import com.hw.langchain.base.language.BaseLanguageModel; +import com.hw.langchain.chains.base.Chain; +import com.hw.langchain.chains.llm.LLMChain; +import com.hw.langchain.prompts.base.BasePromptTemplate; +import com.hw.langchain.requests.TextRequestsWrapper; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static com.hw.langchain.chains.api.prompt.Prompt.API_RESPONSE_PROMPT; +import static com.hw.langchain.chains.api.prompt.Prompt.API_URL_PROMPT; + +/** + * Chain that makes API calls and summarizes the responses to answer a question. + * + * @author HamaWhite + */ +public class ApiChain extends Chain { + + private final LLMChain apiRequestChain; + + private final LLMChain apiAnswerChain; + + private final TextRequestsWrapper requestsWrapper; + + private final String apiDocs; + + private static final String QUESTION_KEY = "question"; + + private static final String OUTPUT_KEY = "output"; + + private static final String API_DOCS = "api_docs"; + + public ApiChain(LLMChain apiRequestChain, LLMChain apiAnswerChain, TextRequestsWrapper requestsWrapper, + String apiDocs) { + this.apiRequestChain = apiRequestChain; + this.apiAnswerChain = apiAnswerChain; + this.requestsWrapper = requestsWrapper; + this.apiDocs = apiDocs; + + // check that api request prompt expects the right variables. + validateApiRequestPrompt(); + // check that api answer prompt expects the right variables. + validateApiAnswerPrompt(); + } + + /** + * Check that api request prompt expects the right variables. + */ + private void validateApiRequestPrompt() { + List inputVars = apiRequestChain.getPrompt().getInputVariables(); + Set inputVarsSet = new HashSet<>(inputVars); + + Set expectedVars = Set.of(QUESTION_KEY, API_DOCS); + if (!inputVarsSet.equals(expectedVars)) { + throw new IllegalArgumentException("Input variables should be " + expectedVars + ", got " + inputVars); + } + } + + /** + * Check that api answer prompt expects the right variables. + */ + private void validateApiAnswerPrompt() { + List inputVars = apiAnswerChain.getPrompt().getInputVariables(); + Set inputVarsSet = new HashSet<>(inputVars); + + Set expectedVars = Set.of(QUESTION_KEY, API_DOCS, "api_url", "api_response"); + if (!inputVarsSet.equals(expectedVars)) { + throw new IllegalArgumentException("Input variables should be " + expectedVars + ", got " + inputVars); + } + } + + @Override + public List inputKeys() { + return List.of(QUESTION_KEY); + } + + @Override + public List outputKeys() { + return List.of(OUTPUT_KEY); + } + + @Override + public Map innerCall(Map inputs) { + var question = inputs.get(QUESTION_KEY); + String apiUrl = apiRequestChain.predict(Map.of(QUESTION_KEY, question, API_DOCS, apiDocs)); + apiUrl = apiUrl.strip(); + + String apiResponse = requestsWrapper.get(apiUrl); + String answer = apiAnswerChain.predict(Map.of(QUESTION_KEY, question, API_DOCS, apiDocs, + "api_url", apiUrl, "api_response", apiResponse)); + return Map.of(OUTPUT_KEY, answer); + } + + public static ApiChain fromLlmAndApiDocs(BaseLanguageModel llm, String apiDocs) { + return fromLlmAndApiDocs(llm, apiDocs, null, API_URL_PROMPT, API_RESPONSE_PROMPT); + } + + public static ApiChain fromLlmAndApiDocs(BaseLanguageModel llm, String apiDocs, Map headers, + BasePromptTemplate apiUrlPrompt, BasePromptTemplate apiResponsePrompt) { + LLMChain getRequestChain = new LLMChain(llm, apiUrlPrompt); + TextRequestsWrapper requestsWrapper = new TextRequestsWrapper(headers); + LLMChain getAnswerChain = new LLMChain(llm, apiResponsePrompt); + return new ApiChain(getRequestChain, getAnswerChain, requestsWrapper, apiDocs); + } + + @Override + public String chainType() { + return "api_chain"; + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/api/meteo/OpenMeteoDocs.java b/langchain-core/src/main/java/com/hw/langchain/chains/api/meteo/OpenMeteoDocs.java new file mode 100644 index 000000000..d2b5565e1 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/api/meteo/OpenMeteoDocs.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.api.meteo; + +/** + * @author HamaWhite + */ +public class OpenMeteoDocs { + + private OpenMeteoDocs() { + } + + public static final String OPEN_METEO_DOCS = + """ + BASE URL: https://api.open-meteo.com/ + + API Documentation + The API endpoint /v1/forecast accepts a geographical coordinate, a list of weather variables and responds with a JSON hourly weather forecast for 7 days. Time always starts at 0:00 today and contains 168 hours. All URL parameters are listed below: + + Parameter Format Required Default Description + latitude, longitude Floating point Yes Geographical WGS84 coordinate of the location + hourly String array No A list of weather variables which should be returned. Values can be comma separated, or multiple &hourly= parameter in the URL can be used. + daily String array No A list of daily weather variable aggregations which should be returned. Values can be comma separated, or multiple &daily= parameter in the URL can be used. If daily weather variables are specified, parameter timezone is required. + current_weather Bool No false Include current weather conditions in the JSON output. + temperature_unit String No celsius If fahrenheit is set, all temperature values are converted to Fahrenheit. + windspeed_unit String No kmh Other wind speed speed units: ms, mph and kn + precipitation_unit String No mm Other precipitation amount units: inch + timeformat String No iso8601 If format unixtime is selected, all time values are returned in UNIX epoch time in seconds. Please note that all timestamp are in GMT+0! For daily values with unix timestamps, please apply utc_offset_seconds again to get the correct date. + timezone String No GMT If timezone is set, all timestamps are returned as local-time and data is returned starting at 00:00 local-time. Any time zone name from the time zone database is supported. If auto is set as a time zone, the coordinates will be automatically resolved to the local time zone. + past_days Integer (0-2) No 0 If past_days is set, yesterday or the day before yesterday data are also returned. + start_date + end_date String (yyyy-mm-dd) No The time interval to get weather data. A day must be specified as an ISO8601 date (e.g. 2022-06-30). + models String array No auto Manually select one or more weather models. Per default, the best suitable weather models will be combined. + + Hourly Parameter Definition + The parameter &hourly= accepts the following values. Most weather variables are given as an instantaneous value for the indicated hour. Some variables like precipitation are calculated from the preceding hour as an average or sum. + + Variable Valid time Unit Description + temperature_2m Instant °C (°F) Air temperature at 2 meters above ground + snowfall Preceding hour sum cm (inch) Snowfall amount of the preceding hour in centimeters. For the water equivalent in millimeter, divide by 7. E.g. 7 cm snow = 10 mm precipitation water equivalent + rain Preceding hour sum mm (inch) Rain from large scale weather systems of the preceding hour in millimeter + showers Preceding hour sum mm (inch) Showers from convective precipitation in millimeters from the preceding hour + weathercode Instant WMO code Weather condition as a numeric code. Follow WMO weather interpretation codes. See table below for details. + snow_depth Instant meters Snow depth on the ground + freezinglevel_height Instant meters Altitude above sea level of the 0°C level + visibility Instant meters Viewing distance in meters. Influenced by low clouds, humidity and aerosols. Maximum visibility is approximately 24 km."""; +} \ No newline at end of file diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/api/prompt/Prompt.java b/langchain-core/src/main/java/com/hw/langchain/chains/api/prompt/Prompt.java new file mode 100644 index 000000000..ff04cdf55 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/api/prompt/Prompt.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.api.prompt; + +import com.hw.langchain.prompts.prompt.PromptTemplate; + +import java.util.List; + +/** + * @author HamaWhite + */ +public class Prompt { + + private Prompt() { + } + + private static final String API_URL_PROMPT_TEMPLATE = + """ + You are given the below API Documentation: + {api_docs} + Using this documentation, generate the full API url to call for answering the user question. + You should build the API url in order to get a response that is as short as possible, while still getting the necessary information to answer the question. Pay attention to deliberately exclude any unnecessary pieces of data in the API call. + + Question:{question} + API url:"""; + + public static final PromptTemplate API_URL_PROMPT = + new PromptTemplate(List.of("api_docs", "question"), API_URL_PROMPT_TEMPLATE); + + private static final String API_RESPONSE_PROMPT_TEMPLATE = + API_URL_PROMPT_TEMPLATE + + """ + {api_url} + + Here is the response from the API: + + {api_response} + + Summarize this response to answer the original question. + + Summary:"""; + + public static final PromptTemplate API_RESPONSE_PROMPT = new PromptTemplate( + List.of("api_docs", "question", "api_url", "api_response"), API_RESPONSE_PROMPT_TEMPLATE); + +} diff --git a/langchain-core/src/main/java/com/hw/langchain/requests/Requests.java b/langchain-core/src/main/java/com/hw/langchain/requests/Requests.java new file mode 100644 index 000000000..c0ad1e464 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/requests/Requests.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.requests; + +import com.google.gson.Gson; + +import okhttp3.*; + +import java.io.IOException; +import java.util.Map; + +/** + * Wrapper around requests to handle auth and async. + * The main purpose of this wrapper is to handle authentication (by saving headers) + * and enable easy async methods on the same base object. + * + * @author HamaWhite + */ +public class Requests { + + private final Map headers; + + private final OkHttpClient client; + + public Requests(Map headers) { + this.headers = headers; + this.client = new OkHttpClient(); + } + + private Request buildRequest(String url, RequestBody body, String method) { + Request.Builder builder = new Request.Builder() + .url(url); + + if (headers != null) { + builder.headers(Headers.of(headers)); + } + builder.method(method, body); + return builder.build(); + } + + private Response executeRequest(Request request) throws IOException { + return client.newCall(request).execute(); + } + + public Response sendRequest(String url, String method, Map data) throws IOException { + RequestBody body = null; + + if (data != null) { + MediaType mediaType = MediaType.parse("application/json"); + String jsonBody = new Gson().toJson(data); + body = RequestBody.create(mediaType, jsonBody); + } + + Request request = buildRequest(url, body, method); + return executeRequest(request); + } + + public Response get(String url) throws IOException { + return sendRequest(url, "GET", null); + } + + public Response post(String url, Map data) throws IOException { + return sendRequest(url, "POST", data); + } + + public Response patch(String url, Map data) throws IOException { + return sendRequest(url, "PATCH", data); + } + + public Response put(String url, Map data) throws IOException { + return sendRequest(url, "PUT", data); + } + + public Response delete(String url) throws IOException { + return sendRequest(url, "DELETE", null); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/requests/TextRequestsWrapper.java b/langchain-core/src/main/java/com/hw/langchain/requests/TextRequestsWrapper.java new file mode 100644 index 000000000..0237543de --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/requests/TextRequestsWrapper.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.requests; + +import com.hw.langchain.exception.LangChainException; + +import okhttp3.Response; +import okhttp3.ResponseBody; + +import java.io.IOException; +import java.util.Map; + +/** + * Lightweight wrapper around requests library. + * The main purpose of this wrapper is to always return a text output. + * + * @author HamaWhite + */ +public class TextRequestsWrapper { + + private final Map headers; + + public TextRequestsWrapper(Map headers) { + this.headers = headers; + } + + private String performRequest(Requests requests, String url, String method, Map data) { + try (Response response = requests.sendRequest(url, method, data)) { + if (response.isSuccessful()) { + ResponseBody responseBody = response.body(); + if (responseBody != null) { + return responseBody.string(); + } + } + } catch (IOException e) { + throw new LangChainException("An error occurred while performing " + method + " request.", e); + } + return null; + } + + public String get(String url) { + Requests requests = getRequests(); + return performRequest(requests, url, "GET", null); + } + + public String post(String url, Map data) { + Requests requests = getRequests(); + return performRequest(requests, url, "POST", data); + } + + public String patch(String url, Map data) { + Requests requests = getRequests(); + return performRequest(requests, url, "PATCH", data); + } + + public String put(String url, Map data) { + Requests requests = getRequests(); + return performRequest(requests, url, "PUT", data); + } + + public String delete(String url) { + Requests requests = getRequests(); + return performRequest(requests, url, "DELETE", null); + } + + private Requests getRequests() { + return new Requests(headers); + } +} diff --git a/langchain-core/src/test/java/com/hw/langchain/chains/api/base/ApiChainTest.java b/langchain-core/src/test/java/com/hw/langchain/chains/api/base/ApiChainTest.java new file mode 100644 index 000000000..f646065d0 --- /dev/null +++ b/langchain-core/src/test/java/com/hw/langchain/chains/api/base/ApiChainTest.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.api.base; + +import com.hw.langchain.llms.openai.OpenAI; + +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static com.hw.langchain.chains.api.meteo.OpenMeteoDocs.OPEN_METEO_DOCS; +import static org.junit.jupiter.api.Assertions.*; + +/** + * API chains + * + * @author HamaWhite + */ +@Disabled("Test requires costly OpenAI calls, can be run manually.") +class ApiChainTest { + + private static final Logger LOG = LoggerFactory.getLogger(ApiChainTest.class); + + @Test + void testApiChain() { + var llm = OpenAI.builder().temperature(0).build().init(); + + var chain = ApiChain.fromLlmAndApiDocs(llm, OPEN_METEO_DOCS); + var result = chain.run("What is the weather like right now in Hangzhou, China in degrees Fahrenheit?"); + + // The current temperature in Hangzhou, China is 79.8°F with a windSpeed of 7.4 km/h and a wind direction of + // 133°. + LOG.info("result: \n{}", result); + assertNotNull(result, "result should not be null"); + } + +} \ No newline at end of file diff --git a/langchain-examples/src/main/java/com/hw/langchain/examples/chains/ApiChainExample.java b/langchain-examples/src/main/java/com/hw/langchain/examples/chains/ApiChainExample.java new file mode 100644 index 000000000..45b34ed56 --- /dev/null +++ b/langchain-examples/src/main/java/com/hw/langchain/examples/chains/ApiChainExample.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.examples.chains; + +import com.hw.langchain.chains.api.base.ApiChain; +import com.hw.langchain.examples.runner.RunnableExample; +import com.hw.langchain.llms.openai.OpenAI; + +import static com.hw.langchain.chains.api.meteo.OpenMeteoDocs.OPEN_METEO_DOCS; +import static com.hw.langchain.examples.utils.PrintUtils.println; + +/** + * API chains + * + * @author HamaWhite + */ +@RunnableExample +public class ApiChainExample { + + public static void main(String[] args) { + var llm = OpenAI.builder().temperature(0).build().init(); + + var chain = ApiChain.fromLlmAndApiDocs(llm, OPEN_METEO_DOCS); + var result = chain.run("What is the weather like right now in Hangzhou, China in degrees Fahrenheit?"); + + // The current temperature in Hangzhou, China is 79.8°F with a windSpeed of 7.4 km/h and a wind direction of + // 133°. + println(result); + } +}