From 13afc63518d9a2eceaf41d8d0e2ec5c37873ff5e Mon Sep 17 00:00:00 2001
From: hupe1980
Date: Tue, 20 Jun 2023 20:53:38 +0200
Subject: [PATCH] Add hugging face llm

---
 agent/executor.go                              |   2 +-
 evaluation/context_qa_eval_chain.go            |  15 +-
 evaluation/cot_qa_eval_chain.go                |  15 +-
 .../huggingfacehub/hugging_face_hub.go         | 137 ++++++++++++++++
 integration/huggingfacehub/types.go            | 155 ++++++++++++++++++
 integration/pinecone/rest.go                   |   2 +-
 model/llm/hugging_face_hub.go                  | 117 +++++++++++++
 model/llm/huggingface.go                       |   3 -
 8 files changed, 429 insertions(+), 17 deletions(-)
 create mode 100644 integration/huggingfacehub/hugging_face_hub.go
 create mode 100644 integration/huggingfacehub/types.go
 create mode 100644 model/llm/hugging_face_hub.go
 delete mode 100644 model/llm/huggingface.go

diff --git a/agent/executor.go b/agent/executor.go
index b665153..6c57a49 100644
--- a/agent/executor.go
+++ b/agent/executor.go
@@ -99,7 +99,7 @@ func (e Executor) Memory() schema.Memory {
 }
 
 func (e Executor) Type() string {
-	return "Executor"
+	return "AgentExecutor"
 }
 
 func (e Executor) Verbose() bool {
diff --git a/evaluation/context_qa_eval_chain.go b/evaluation/context_qa_eval_chain.go
index d24383e..7e30ae8 100644
--- a/evaluation/context_qa_eval_chain.go
+++ b/evaluation/context_qa_eval_chain.go
@@ -44,13 +44,7 @@ type ContextQAEvalChain struct {
 }
 
 func NewContextQAEvalChain(llm schema.LLM, optFns ...func(o *ContextQAEvalChainOptions)) (*ContextQAEvalChain, error) {
-	contextQAEvalPrompt, err := prompt.NewTemplate(contextQAEvalTemplate)
-	if err != nil {
-		return nil, err
-	}
-
 	opts := ContextQAEvalChainOptions{
-		Prompt:        contextQAEvalPrompt,
 		QuestionKey:   "query",
 		ContextKey:    "context",
 		PredictionKey: "result",
@@ -60,6 +54,15 @@ func NewContextQAEvalChain(llm schema.LLM, optFns ...func(o *ContextQAEvalChainO
 		fn(&opts)
 	}
 
+	if opts.Prompt == nil {
+		contextQAEvalPrompt, err := prompt.NewTemplate(contextQAEvalTemplate)
+		if err != nil {
+			return nil, err
+		}
+
+		opts.Prompt = contextQAEvalPrompt
+	}
+
 	llmChain, err := chain.NewLLMChain(llm, opts.Prompt)
 	if err != nil {
 		return nil, err
diff --git a/evaluation/cot_qa_eval_chain.go b/evaluation/cot_qa_eval_chain.go
index 848da9d..034e84d 100644
--- a/evaluation/cot_qa_eval_chain.go
+++ b/evaluation/cot_qa_eval_chain.go
@@ -36,13 +36,7 @@ type COTQAEvalChain struct {
 }
 
 func NewCOTQAEvalChain(llm schema.LLM, optFns ...func(o *COTQAEvalChainOptions)) (*COTQAEvalChain, error) {
-	cotQAEvalPrompt, err := prompt.NewTemplate(cotQAEvalTemplate)
-	if err != nil {
-		return nil, err
-	}
-
 	opts := COTQAEvalChainOptions{
-		Prompt:        cotQAEvalPrompt,
 		QuestionKey:   "query",
 		ContextKey:    "context",
 		PredictionKey: "result",
@@ -52,6 +46,15 @@ func NewCOTQAEvalChain(llm schema.LLM, optFns ...func(o *COTQAEvalChainOptions))
 		fn(&opts)
 	}
 
+	if opts.Prompt == nil {
+		cotQAEvalPrompt, err := prompt.NewTemplate(cotQAEvalTemplate)
+		if err != nil {
+			return nil, err
+		}
+
+		opts.Prompt = cotQAEvalPrompt
+	}
+
 	contextQAEvalChain, err := NewContextQAEvalChain(llm, func(o *ContextQAEvalChainOptions) {
 		o.Prompt = opts.Prompt
 		o.QuestionKey = opts.QuestionKey
diff --git a/integration/huggingfacehub/hugging_face_hub.go b/integration/huggingfacehub/hugging_face_hub.go
new file mode 100644
index 0000000..31aa351
--- /dev/null
+++ b/integration/huggingfacehub/hugging_face_hub.go
@@ -0,0 +1,137 @@
+package huggingfacehub
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+)
+
+const apiEndpoint = "https://api-inference.huggingface.co"
+
+type Client struct {
+	apiToken string
+	repoID   string
+	task     string
+}
+
+func New(apiToken, repoID, task string) *Client {
+	return &Client{
+		apiToken: apiToken,
+		repoID:   repoID,
+		task:     task,
+	}
+}
+
+func (hf *Client) Summarization(ctx context.Context, req *SummarizationRequest) (*SummarizationResponse, error) {
+	reqURL := fmt.Sprintf("%s/pipeline/%s/%s", apiEndpoint, hf.task, hf.repoID)
+
+	res, err := hf.doRequest(ctx, http.MethodPost, reqURL, req)
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+
+	body, err := io.ReadAll(res.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	if res.StatusCode != http.StatusOK {
+		errResp := ErrorResponse{}
+		if err := json.Unmarshal(body, &errResp); err != nil {
+			return nil, err
+		}
+
+		return nil, fmt.Errorf("hugging face error: %s", errResp.Error)
+	}
+
+	summarizationResponse := SummarizationResponse{}
+	if err := json.Unmarshal(body, &summarizationResponse); err != nil {
+		return nil, err
+	}
+
+	return &summarizationResponse, nil
+}
+
+func (hf *Client) TextGeneration(ctx context.Context, req *TextGenerationRequest) (TextGenerationResponse, error) {
+	reqURL := fmt.Sprintf("%s/pipeline/%s/%s", apiEndpoint, hf.task, hf.repoID)
+
+	res, err := hf.doRequest(ctx, http.MethodPost, reqURL, req)
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+
+	body, err := io.ReadAll(res.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	if res.StatusCode != http.StatusOK {
+		errResp := ErrorResponse{}
+		if err := json.Unmarshal(body, &errResp); err != nil {
+			return nil, err
+		}
+
+		return nil, fmt.Errorf("hugging face error: %s", errResp.Error)
+	}
+
+	textGenerations := TextGenerationResponse{}
+	if err := json.Unmarshal(body, &textGenerations); err != nil {
+		return nil, err
+	}
+
+	return textGenerations, nil
+}
+
+func (hf *Client) Text2TextGeneration(ctx context.Context, req *Text2TextGenerationRequest) (Text2TextGenerationResponse, error) {
+	reqURL := fmt.Sprintf("%s/pipeline/%s/%s", apiEndpoint, hf.task, hf.repoID)
+
+	res, err := hf.doRequest(ctx, http.MethodPost, reqURL, req)
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+
+	body, err := io.ReadAll(res.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	if res.StatusCode != http.StatusOK {
+		errResp := ErrorResponse{}
+		if err := json.Unmarshal(body, &errResp); err != nil {
+			return nil, err
+		}
+
+		return nil, fmt.Errorf("hugging face error: %s", errResp.Error)
+	}
+
+	text2TextGenerationResponse := Text2TextGenerationResponse{}
+	if err := json.Unmarshal(body, &text2TextGenerationResponse); err != nil {
+		return nil, err
+	}
+
+	return text2TextGenerationResponse, nil
+}
+
+func (hf *Client) doRequest(ctx context.Context, method string, url string, payload any) (*http.Response, error) {
+	body, err := json.Marshal(payload)
+	if err != nil {
+		return nil, err
+	}
+
+	httpReq, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(body))
+	if err != nil {
+		return nil, err
+	}
+
+	httpReq.Header.Set("Content-Type", "application/json")
+	httpReq.Header.Set("Accept", "application/json")
+	httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", hf.apiToken))
+
+	return http.DefaultClient.Do(httpReq)
+}
diff --git a/integration/huggingfacehub/types.go b/integration/huggingfacehub/types.go
new file mode 100644
index 0000000..599cfc8
--- /dev/null
+++ b/integration/huggingfacehub/types.go
@@ -0,0 +1,155 @@
+package huggingfacehub
+
+type Options struct {
+	// (Default: true). There is a cache layer on the inference API to speed up
+	// requests we have already seen. Most models can use those results as is,
+	// as models are deterministic (meaning the results will be the same anyway).
+	// However, if you use a non-deterministic model, you can set this parameter
+	// to prevent the caching mechanism from being used, resulting in a real new query.
+	UseCache *bool `json:"use_cache,omitempty"`
+
+	// (Default: false). If the model is not ready, wait for it instead of receiving 503.
+	// It limits the number of requests required to get your inference done. It is advised
+	// to only set this flag to true after receiving a 503 error as it will limit hanging
+	// in your application to known places.
+	WaitForModel *bool `json:"wait_for_model,omitempty"`
+}
+
+type SummarizationParameters struct {
+	// (Default: None). Integer to define the minimum length in tokens of the output summary.
+	MinLength *int `json:"min_length,omitempty"`
+
+	// (Default: None). Integer to define the maximum length in tokens of the output summary.
+	MaxLength *int `json:"max_length,omitempty"`
+
+	// (Default: None). Integer to define the top tokens considered within the sample operation to create
+	// new text.
+	TopK *int `json:"top_k,omitempty"`
+
+	// (Default: None). Float to define the tokens that are within the sample operation of text generation.
+	// Add tokens in the sample from most probable to least probable until the sum of the probabilities is
+	// greater than top_p.
+	TopP *float64 `json:"top_p,omitempty"`
+
+	// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
+	// 0 means top_k=1, 100.0 is getting closer to uniform probability.
+	Temperature *float64 `json:"temperature,omitempty"`
+
+	// (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized
+	// to not be picked in successive generation passes.
+	RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`
+
+	// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take.
+	// Network can cause some overhead so it will be a soft limit.
+	MaxTime *float64 `json:"max_time,omitempty"`
+}
+
+type SummarizationRequest struct {
+	// String to be summarized
+	Inputs     string                  `json:"inputs"`
+	Parameters SummarizationParameters `json:"parameters,omitempty"`
+	Options    Options                 `json:"options,omitempty"`
+}
+
+type SummarizationResponse struct {
+	// The summarized input string
+	SummaryText string `json:"summary_text,omitempty"`
+}
+
+type TextGenerationParameters struct {
+	// (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
+	TopK *int `json:"top_k,omitempty"`
+
+	// (Default: None). Float to define the tokens that are within the sample operation of text generation. Add
+	// tokens in the sample from most probable to least probable until the sum of the probabilities is greater
+	// than top_p.
+	TopP *float64 `json:"top_p,omitempty"`
+
+	// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
+	// 0 means top_k=1, 100.0 is getting closer to uniform probability.
+	Temperature *float64 `json:"temperature,omitempty"`
+
+	// (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized
+	// to not be picked in successive generation passes.
+	RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`
+
+	// (Default: None). Int (0-250). The number of new tokens to be generated; this does not include the input
+	// length, it is an estimate of the size of the generated text you want. Each new token slows down the request,
+	// so look for a balance between response time and length of text generated.
+	MaxNewTokens *int `json:"max_new_tokens,omitempty"`
+
+	// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take.
+	// Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens
+	// for best results.
+	MaxTime *float64 `json:"max_time,omitempty"`
+
+	// (Default: True). Bool. If set to False, the return results will not contain the original query, making it
+	// easier for prompting.
+	ReturnFullText *bool `json:"return_full_text,omitempty"`
+
+	// (Default: 1). Integer. The number of propositions you want to be returned.
+	NumReturnSequences *int `json:"num_return_sequences,omitempty"`
+}
+
+type TextGenerationRequest struct {
+	// String to generate from
+	Inputs     string                   `json:"inputs"`
+	Parameters TextGenerationParameters `json:"parameters,omitempty"`
+	Options    Options                  `json:"options,omitempty"`
+}
+
+// A list of generated texts. The length of this list is the value of
+// NumReturnSequences in the request.
+type TextGenerationResponse []struct {
+	GeneratedText string `json:"generated_text,omitempty"`
+}
+
+type Text2TextGenerationParameters struct {
+	// (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
+	TopK *int `json:"top_k,omitempty"`
+
+	// (Default: None). Float to define the tokens that are within the sample operation of text generation. Add
+	// tokens in the sample from most probable to least probable until the sum of the probabilities is greater
+	// than top_p.
+	TopP *float64 `json:"top_p,omitempty"`
+
+	// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
+	// 0 means top_k=1, 100.0 is getting closer to uniform probability.
+	Temperature *float64 `json:"temperature,omitempty"`
+
+	// (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized
+	// to not be picked in successive generation passes.
+	RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`
+
+	// (Default: None). Int (0-250). The number of new tokens to be generated; this does not include the input
+	// length, it is an estimate of the size of the generated text you want. Each new token slows down the request,
+	// so look for a balance between response time and length of text generated.
+	MaxNewTokens *int `json:"max_new_tokens,omitempty"`
+
+	// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take.
+	// Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens
+	// for best results.
+	MaxTime *float64 `json:"max_time,omitempty"`
+
+	// (Default: True). Bool. If set to False, the return results will not contain the original query, making it
+	// easier for prompting.
+	ReturnFullText *bool `json:"return_full_text,omitempty"`
+
+	// (Default: 1). Integer. The number of propositions you want to be returned.
+	NumReturnSequences *int `json:"num_return_sequences,omitempty"`
+}
+
+type Text2TextGenerationRequest struct {
+	// String to generate from
+	Inputs     string                        `json:"inputs"`
+	Parameters Text2TextGenerationParameters `json:"parameters,omitempty"`
+	Options    Options                       `json:"options,omitempty"`
+}
+
+type Text2TextGenerationResponse []struct {
+	GeneratedText string `json:"generated_text,omitempty"`
+}
+
+type ErrorResponse struct {
+	Error string `json:"error"`
+}
diff --git a/integration/pinecone/rest.go b/integration/pinecone/rest.go
index d6d43f1..a41d086 100644
--- a/integration/pinecone/rest.go
+++ b/integration/pinecone/rest.go
@@ -59,7 +59,7 @@ func (p *RestClient) Fetch(ctx context.Context, req *FetchRequest) (*FetchRespon
 
 	reqURL := fmt.Sprintf("https://%s/vectors/fetch?%s", p.target, params.Encode())
 
-	res, err := p.doRequest(ctx, http.MethodPost, reqURL, req)
+	res, err := p.doRequest(ctx, http.MethodGet, reqURL, req)
 	if err != nil {
 		return nil, err
 	}
diff --git a/model/llm/hugging_face_hub.go b/model/llm/hugging_face_hub.go
new file mode 100644
index 0000000..047fa0e
--- /dev/null
+++ b/model/llm/hugging_face_hub.go
@@ -0,0 +1,117 @@
+package llm
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/hupe1980/golc"
+	"github.com/hupe1980/golc/integration/huggingfacehub"
+	"github.com/hupe1980/golc/schema"
+)
+
+// Compile-time check to ensure HuggingFaceHub satisfies the LLM interface.
+var _ schema.LLM = (*HuggingFaceHub)(nil)
+
+type HuggingFaceHubOptions struct {
+	*schema.CallbackOptions
+	// Model name to use.
+	RepoID string
+	Task   string
+}
+
+type HuggingFaceHub struct {
+	schema.Tokenizer
+	client *huggingfacehub.Client
+	opts   HuggingFaceHubOptions
+}
+
+func NewHuggingFaceHub(apiToken string, optFns ...func(o *HuggingFaceHubOptions)) (*HuggingFaceHub, error) {
+	opts := HuggingFaceHubOptions{
+		CallbackOptions: &schema.CallbackOptions{
+			Verbose: golc.Verbose,
+		},
+		RepoID: "gpt2",
+		Task:   "text-generation",
+	}
+
+	for _, fn := range optFns {
+		fn(&opts)
+	}
+
+	return &HuggingFaceHub{
+		client: huggingfacehub.New(apiToken, opts.RepoID, opts.Task),
+		opts:   opts,
+	}, nil
+}
+
+func (l *HuggingFaceHub) Generate(ctx context.Context, prompts []string, stop []string) (*schema.LLMResult, error) {
+	var (
+		text string
+		err  error
+	)
+
+	if l.opts.Task == "text-generation" {
+		text, err = l.textGeneration(ctx, prompts[0])
+	} else if l.opts.Task == "text2text-generation" {
+		text, err = l.text2textGeneration(ctx, prompts[0])
+	} else if l.opts.Task == "summarization" {
+		text, err = l.summarization(ctx, prompts[0])
+	} else {
+		err = fmt.Errorf("unknown task: %s", l.opts.Task)
+	}
+
+	if err != nil {
+		return nil, err
+	}
+
+	return &schema.LLMResult{
+		Generations: [][]*schema.Generation{{&schema.Generation{Text: text}}},
+		LLMOutput:   map[string]any{},
+	}, nil
+}
+
+func (l *HuggingFaceHub) textGeneration(ctx context.Context, input string) (string, error) {
+	res, err := l.client.TextGeneration(ctx, &huggingfacehub.TextGenerationRequest{
+		Inputs: input,
+	})
+	if err != nil {
+		return "", err
+	}
+
+	// The text-generation response includes the starter text, so strip the input prefix.
+	return res[0].GeneratedText[len(input):], nil
+}
+
+func (l *HuggingFaceHub) text2textGeneration(ctx context.Context, input string) (string, error) {
+	res, err := l.client.Text2TextGeneration(ctx, &huggingfacehub.Text2TextGenerationRequest{
+		Inputs: input,
+	})
+	if err != nil {
+		return "", err
+	}
+
+	return res[0].GeneratedText, nil
+}
+
+func (l *HuggingFaceHub) summarization(ctx context.Context, input string) (string, error) {
+	res, err := l.client.Summarization(ctx, &huggingfacehub.SummarizationRequest{
+		Inputs: input,
+	})
+	if err != nil {
+		return "", err
+	}
+
+	return res.SummaryText, nil
+}
+
+func (l *HuggingFaceHub) Type() string {
+	return "HuggingFaceHub"
+}
+
+func (l *HuggingFaceHub) Verbose() bool {
+	return l.opts.CallbackOptions.Verbose
+}
+
+func (l *HuggingFaceHub) Callbacks() []schema.Callback {
+	return l.opts.CallbackOptions.Callbacks
+}
diff --git a/model/llm/huggingface.go b/model/llm/huggingface.go
deleted file mode 100644
index 74d04be..0000000
--- a/model/llm/huggingface.go
+++ /dev/null
@@ -1,3 +0,0 @@
-package llm
-
-type HuggingFaceOptions struct{}
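
Usage sketch (not part of the patch above): with this change the evaluation chains build their default prompt lazily, so the zero-option constructors now work and the built-in template is only parsed when the caller does not supply one. The package import paths below are inferred from the repository layout and are an assumption, not something stated in the diff.

package main

import (
	"log"

	"github.com/hupe1980/golc/evaluation" // assumed path, derived from evaluation/*.go above
	"github.com/hupe1980/golc/schema"     // assumed path, matching the schema.LLM parameter of the constructors
)

// buildEvalChains is a hypothetical helper; llm can be any schema.LLM implementation.
func buildEvalChains(llm schema.LLM) {
	// Zero-option call: the built-in contextQAEvalTemplate is parsed on demand.
	contextChain, err := evaluation.NewContextQAEvalChain(llm)
	if err != nil {
		log.Fatal(err)
	}

	// Options still override the defaults, exactly as before this patch.
	cotChain, err := evaluation.NewCOTQAEvalChain(llm, func(o *evaluation.COTQAEvalChainOptions) {
		o.QuestionKey = "question"
		o.PredictionKey = "answer"
	})
	if err != nil {
		log.Fatal(err)
	}

	_ = contextChain
	_ = cotChain
}

func main() {}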
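
A second sketch, likewise not part of the diff: calling the new huggingfacehub.Client directly for the text-generation pipeline. The types, fields, and import path all come from integration/huggingfacehub above; the environment variable name, model id, and parameter values are placeholders chosen for illustration.

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/hupe1980/golc/integration/huggingfacehub"
)

func main() {
	// Placeholder token/model: the env var name is an assumption, not defined by the patch.
	client := huggingfacehub.New(os.Getenv("HUGGINGFACEHUB_API_TOKEN"), "gpt2", "text-generation")

	maxNewTokens := 64
	temperature := 0.7
	waitForModel := true

	res, err := client.TextGeneration(context.Background(), &huggingfacehub.TextGenerationRequest{
		Inputs: "Once upon a time",
		Parameters: huggingfacehub.TextGenerationParameters{
			MaxNewTokens: &maxNewTokens,
			Temperature:  &temperature,
		},
		Options: huggingfacehub.Options{
			WaitForModel: &waitForModel,
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	if len(res) == 0 {
		log.Fatal("empty text generation response")
	}

	// TextGenerationResponse is a slice; with the default NumReturnSequences it has one entry.
	fmt.Println(res[0].GeneratedText)
}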
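
Finally, a sketch of the higher-level LLM wrapper from model/llm/hugging_face_hub.go. The import path of the llm package is inferred from the repository layout, and google/flan-t5-small is only an illustrative repo id for the text2text-generation task; neither is prescribed by the patch.

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/hupe1980/golc/model/llm" // assumed path, derived from model/llm/hugging_face_hub.go above
)

func main() {
	hfLLM, err := llm.NewHuggingFaceHub(os.Getenv("HUGGINGFACEHUB_API_TOKEN"), func(o *llm.HuggingFaceHubOptions) {
		// Defaults are gpt2 / text-generation; switch to a text2text model as an example.
		o.RepoID = "google/flan-t5-small"
		o.Task = "text2text-generation"
	})
	if err != nil {
		log.Fatal(err)
	}

	result, err := hfLLM.Generate(context.Background(), []string{"Translate to German: How old are you?"}, nil)
	if err != nil {
		log.Fatal(err)
	}

	// Generate wraps the single completion in Generations[0][0].
	fmt.Println(result.Generations[0][0].Text)
}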