Add hugging face llm
hupe1980 committed Jun 20, 2023
1 parent 7840eaf commit 13afc63
Showing 8 changed files with 429 additions and 17 deletions.
2 changes: 1 addition & 1 deletion agent/executor.go
@@ -99,7 +99,7 @@ func (e Executor) Memory() schema.Memory {
}

func (e Executor) Type() string {
return "Executor"
return "AgentExecutor"
}

func (e Executor) Verbose() bool {
15 changes: 9 additions & 6 deletions evaluation/context_qa_eval_chain.go
@@ -44,13 +44,7 @@ type ContextQAEvalChain
}

func NewContextQAEvalChain(llm schema.LLM, optFns ...func(o *ContextQAEvalChainOptions)) (*ContextQAEvalChain, error) {
-contextQAEvalPrompt, err := prompt.NewTemplate(contextQAEvalTemplate)
-if err != nil {
-return nil, err
-}
-
opts := ContextQAEvalChainOptions{
-Prompt: contextQAEvalPrompt,
QuestionKey: "query",
ContextKey: "context",
PredictionKey: "result",
@@ -60,6 +54,15 @@ func NewContextQAEvalChain(llm schema.LLM, optFns ...func(o *ContextQAEvalChainO
fn(&opts)
}

+if opts.Prompt == nil {
+contextQAEvalPrompt, err := prompt.NewTemplate(contextQAEvalTemplate)
+if err != nil {
+return nil, err
+}
+
+opts.Prompt = contextQAEvalPrompt
+}

llmChain, err := chain.NewLLMChain(llm, opts.Prompt)
if err != nil {
return nil, err
15 changes: 9 additions & 6 deletions evaluation/cot_qa_eval_chain.go
@@ -36,13 +36,7 @@ type COTQAEvalChain
}

func NewCOTQAEvalChain(llm schema.LLM, optFns ...func(o *COTQAEvalChainOptions)) (*COTQAEvalChain, error) {
-cotQAEvalPrompt, err := prompt.NewTemplate(cotQAEvalTemplate)
-if err != nil {
-return nil, err
-}
-
opts := COTQAEvalChainOptions{
-Prompt: cotQAEvalPrompt,
QuestionKey: "query",
ContextKey: "context",
PredictionKey: "result",
@@ -52,6 +46,15 @@ func NewCOTQAEvalChain(llm schema.LLM, optFns ...func(o *COTQAEvalChainOptions))
fn(&opts)
}

+if opts.Prompt == nil {
+cotQAEvalPrompt, err := prompt.NewTemplate(cotQAEvalTemplate)
+if err != nil {
+return nil, err
+}
+
+opts.Prompt = cotQAEvalPrompt
+}

contextQAEvalChain, err := NewContextQAEvalChain(llm, func(o *ContextQAEvalChainOptions) {
o.Prompt = opts.Prompt
o.QuestionKey = opts.QuestionKey
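With this change, both eval-chain constructors build the default grading prompt only when the caller leaves Prompt unset, so a custom prompt can be injected through the option function. A minimal sketch of that usage follows; the function name, the template argument, and the import paths are assumptions for illustration, not part of this commit.

package example

import (
	// Import paths assumed from this repository's layout.
	"github.com/hupe1980/golc/evaluation"
	"github.com/hupe1980/golc/prompt"
	"github.com/hupe1980/golc/schema"
)

// newCustomCOTEvalChain builds a COT QA eval chain with a caller-supplied prompt template.
func newCustomCOTEvalChain(llm schema.LLM, customTemplate string) (*evaluation.COTQAEvalChain, error) {
	customPrompt, err := prompt.NewTemplate(customTemplate)
	if err != nil {
		return nil, err
	}

	// Because opts.Prompt is non-nil, the constructor skips the built-in cotQAEvalTemplate.
	return evaluation.NewCOTQAEvalChain(llm, func(o *evaluation.COTQAEvalChainOptions) {
		o.Prompt = customPrompt
	})
}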
137 changes: 137 additions & 0 deletions integration/huggingfacehub/hugging_face_hub.go
@@ -0,0 +1,137 @@
package huggingfacehub

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
)

const apiEndpoint = "https://api-inference.huggingface.co"

// Client is a small HTTP client for the Hugging Face Inference API.
type Client struct {
apiToken string
repoID string
task string
}

// New creates a Client for the given API token, model repository ID, and pipeline task.
func New(apiToken, repoID, task string) *Client {
return &Client{
apiToken: apiToken,
repoID: repoID,
task: task,
}
}

// Summarization calls the summarization pipeline of the configured model and returns the summary.
func (hf *Client) Summarization(ctx context.Context, req *SummarizationRequest) (*SummarizationResponse, error) {
reqURL := fmt.Sprintf("%s/pipeline/%s/%s", apiEndpoint, hf.task, hf.repoID)

res, err := hf.doRequest(ctx, http.MethodPost, reqURL, req)
if err != nil {
return nil, err
}
defer res.Body.Close()

body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}

if res.StatusCode != http.StatusOK {
errResp := ErrorResponse{}
if err := json.Unmarshal(body, &errResp); err != nil {
return nil, err
}

return nil, fmt.Errorf("hugging faces error: %s", errResp.Error)
}

summarizationResponse := SummarizationResponse{}
if err := json.Unmarshal(body, &summarizationResponse); err != nil {
return nil, err
}

return &summarizationResponse, nil
}

// TextGeneration calls the text-generation pipeline of the configured model.
func (hf *Client) TextGeneration(ctx context.Context, req *TextGenerationRequest) (TextGenerationResponse, error) {
reqURL := fmt.Sprintf("%s/pipeline/%s/%s", apiEndpoint, hf.task, hf.repoID)

res, err := hf.doRequest(ctx, http.MethodPost, reqURL, req)
if err != nil {
return nil, err
}
defer res.Body.Close()

body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}

if res.StatusCode != http.StatusOK {
errResp := ErrorResponse{}
if err := json.Unmarshal(body, &errResp); err != nil {
return nil, err
}

return nil, fmt.Errorf("hugging faces error: %s", errResp.Error)
}

textGenerations := TextGenerationResponse{}
if err := json.Unmarshal(body, &textGenerations); err != nil {
return nil, err
}

return textGenerations, nil
}

// Text2TextGeneration calls the text2text-generation pipeline of the configured model.
func (hf *Client) Text2TextGeneration(ctx context.Context, req *Text2TextGenerationRequest) (Text2TextGenerationResponse, error) {
reqURL := fmt.Sprintf("%s/pipeline/%s/%s", apiEndpoint, hf.task, hf.repoID)

res, err := hf.doRequest(ctx, http.MethodPost, reqURL, req)
if err != nil {
return nil, err
}
defer res.Body.Close()

body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}

if res.StatusCode != http.StatusOK {
errResp := ErrorResponse{}
if err := json.Unmarshal(body, &errResp); err != nil {
return nil, err
}

return nil, fmt.Errorf("hugging faces error: %s", errResp.Error)
}

text2TextGenerationResponse := Text2TextGenerationResponse{}
if err := json.Unmarshal(body, &text2TextGenerationResponse); err != nil {
return nil, err
}

return text2TextGenerationResponse, nil
}

// doRequest marshals the payload to JSON and sends an authenticated request to the inference API.
func (hf *Client) doRequest(ctx context.Context, method string, url string, payload any) (*http.Response, error) {
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}

httpReq, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(body))
if err != nil {
return nil, err
}

httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", hf.apiToken))

return http.DefaultClient.Do(httpReq)
}
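For reference, a short sketch of how the new client might be called for text generation. The import path, the model repository ID, and the parameter values are assumptions for illustration, not part of this commit.

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	// Import path assumed from this repository's layout.
	"github.com/hupe1980/golc/integration/huggingfacehub"
)

func main() {
	// The repo ID and task are illustrative; any text-generation model on the Hub is called the same way.
	client := huggingfacehub.New(os.Getenv("HUGGINGFACEHUB_API_TOKEN"), "gpt2", "text-generation")

	maxNewTokens := 64
	res, err := client.TextGeneration(context.Background(), &huggingfacehub.TextGenerationRequest{
		Inputs: "Once upon a time",
		Parameters: huggingfacehub.TextGenerationParameters{
			MaxNewTokens: &maxNewTokens,
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	// TextGenerationResponse is a slice; the first element holds the generated text.
	fmt.Println(res[0].GeneratedText)
}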
155 changes: 155 additions & 0 deletions integration/huggingfacehub/types.go
@@ -0,0 +1,155 @@
package huggingfacehub

type Options struct {
// (Default: true). There is a cache layer on the inference API to speed up
// requests we have already seen. Most models can use those results as-is,
// as models are deterministic (meaning the results will be the same anyway).
// However, if you use a non-deterministic model, you can set this parameter
// to prevent the caching mechanism from being used, resulting in a real new query.
UseCache *bool `json:"use_cache,omitempty"`

// (Default: false) If the model is not ready, wait for it instead of receiving 503.
// It limits the number of requests required to get your inference done. It is advised
// to only set this flag to true after receiving a 503 error, as it will limit hanging
// in your application to known places.
WaitForModel *bool `json:"wait_for_model,omitempty"`
}

type SummarizationParameters struct {
// (Default: None). Integer to define the minimum length in tokens of the output summary.
MinLength *int `json:"min_length,omitempty"`

// (Default: None). Integer to define the maximum length in tokens of the output summary.
MaxLength *int `json:"max_length,omitempty"`

// (Default: None). Integer to define the top tokens considered within the sample operation to create
// new text.
TopK *int `json:"top_k,omitempty"`

// (Default: None). Float to define the tokens that are within the sample operation of text generation.
// Add tokens in the sample from most probable to least probable until the sum of the probabilities is
// greater than top_p.
TopP *float64 `json:"top_p,omitempty"`

// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
// 0 means top_k=1, 100.0 is getting closer to uniform probability.
Temperature *float64 `json:"temperature,omitempty"`

// (Default: None). Float (0.0-100.0). The more a token is used within generation, the more it is penalized
// so it is not picked in successive generation passes.
RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`

// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take.
// Network can cause some overhead, so it is a soft limit.
MaxTime *float64 `json:"max_time,omitempty"`
}

type SummarizationRequest struct {
// String to be summarized
Inputs string `json:"inputs"`
Parameters SummarizationParameters `json:"parameters,omitempty"`
Options Options `json:"options,omitempty"`
}

type SummarizationResponse struct {
// The summarized input string
SummaryText string `json:"summary_text,omitempty"`
}

type TextGenerationParameters struct {
// (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
TopK *int `json:"top_k,omitempty"`

// (Default: None). Float to define the tokens that are within the sample operation of text generation. Add
// tokens in the sample from most probable to least probable until the sum of the probabilities is greater
// than top_p.
TopP *float64 `json:"top_p,omitempty"`

// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
// 0 means top_k=1, 100.0 is getting closer to uniform probability.
Temperature *float64 `json:"temperature,omitempty"`

// (Default: None). Float (0.0-100.0). The more a token is used within generation, the more it is penalized
// so it is not picked in successive generation passes.
RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`

// (Default: None). Int (0-250). The number of new tokens to be generated; this does not include the input
// length, it is an estimate of the size of the generated text you want. Each new token slows down the request,
// so look for a balance between response time and length of generated text.
MaxNewTokens *int `json:"max_new_tokens,omitempty"`

// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take.
// Network can cause some overhead, so it is a soft limit. Use that in combination with max_new_tokens
// for best results.
MaxTime *float64 `json:"max_time,omitempty"`

// (Default: True). Bool. If set to False, the returned results will not contain the original query, making it
// easier for prompting.
ReturnFullText *bool `json:"return_full_text,omitempty"`

// (Default: 1). Integer. The number of propositions you want to be returned.
NumReturnSequences *int `json:"num_return_sequences,omitempty"`
}

type TextGenerationRequest struct {
// String to generate from
Inputs string `json:"inputs"`
Parameters TextGenerationParameters `json:"parameters,omitempty"`
Options Options `json:"options,omitempty"`
}

// A list of generated texts. The length of this list is the value of
// NumReturnSequences in the request.
type TextGenerationResponse []struct {
GeneratedText string `json:"generated_text,omitempty"`
}

type Text2TextGenerationParameters struct {
// (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
TopK *int `json:"top_k,omitempty"`

// (Default: None). Float to define the tokens that are within the sample operation of text generation. Add
// tokens in the sample from most probable to least probable until the sum of the probabilities is greater
// than top_p.
TopP *float64 `json:"top_p,omitempty"`

// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
// 0 means top_k=1, 100.0 is getting closer to uniform probability.
Temperature *float64 `json:"temperature,omitempty"`

// (Default: None). Float (0.0-100.0). The more a token is used within generation, the more it is penalized
// so it is not picked in successive generation passes.
RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`

// (Default: None). Int (0-250). The number of new tokens to be generated; this does not include the input
// length, it is an estimate of the size of the generated text you want. Each new token slows down the request,
// so look for a balance between response time and length of generated text.
MaxNewTokens *int `json:"max_new_tokens,omitempty"`

// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take.
// Network can cause some overhead, so it is a soft limit. Use that in combination with max_new_tokens
// for best results.
MaxTime *float64 `json:"max_time,omitempty"`

// (Default: True). Bool. If set to False, the returned results will not contain the original query, making it
// easier for prompting.
ReturnFullText *bool `json:"return_full_text,omitempty"`

// (Default: 1). Integer. The number of propositions you want to be returned.
NumReturnSequences *int `json:"num_return_sequences,omitempty"`
}

type Text2TextGenerationRequest struct {
// String to generate from
Inputs string `json:"inputs"`
Parameters Text2TextGenerationParameters `json:"parameters,omitempty"`
Options Options `json:"options,omitempty"`
}

type Text2TextGenerationResponse []struct {
GeneratedText string `json:"generated_text,omitempty"`
}

type ErrorResponse struct {
Error string `json:"error"`
}
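Since every optional parameter above is a pointer (so unset fields are dropped from the JSON payload), a small generic helper keeps request construction readable. The sketch below is written as if it lived in the huggingfacehub package; the helper and the concrete values are illustrative assumptions, not part of this commit.

// ptr returns a pointer to v; convenient for the optional pointer fields above (requires Go 1.18+).
func ptr[T any](v T) *T { return &v }

// newSummarizationRequest shows how a summarization request with optional
// parameters could be assembled; the values are illustrative.
func newSummarizationRequest(text string) *SummarizationRequest {
	return &SummarizationRequest{
		Inputs: text,
		Parameters: SummarizationParameters{
			MinLength:   ptr(30),
			MaxLength:   ptr(120),
			Temperature: ptr(1.0),
		},
		Options: Options{
			WaitForModel: ptr(true),
		},
	}
}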
2 changes: 1 addition & 1 deletion integration/pinecone/rest.go
@@ -59,7 +59,7 @@ func (p *RestClient) Fetch(ctx context.Context, req *FetchRequest) (*FetchRespon

reqURL := fmt.Sprintf("https://%s/vectors/fetch?%s", p.target, params.Encode())

-res, err := p.doRequest(ctx, http.MethodPost, reqURL, req)
+res, err := p.doRequest(ctx, http.MethodGet, reqURL, req)
if err != nil {
return nil, err
}