Showing 8 changed files with 429 additions and 17 deletions.
@@ -0,0 +1,137 @@
package huggingfacehub

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

const apiEndpoint = "https://api-inference.huggingface.co"

// Client calls the Hugging Face Inference API for a single model
// repository and pipeline task.
type Client struct {
	apiToken string
	repoID   string
	task     string
}

// New creates a Client for the given API token, model repository ID, and
// pipeline task (e.g. "summarization", "text-generation").
func New(apiToken, repoID, task string) *Client {
	return &Client{
		apiToken: apiToken,
		repoID:   repoID,
		task:     task,
	}
}

func (hf *Client) Summarization(ctx context.Context, req *SummarizationRequest) (*SummarizationResponse, error) {
	reqURL := fmt.Sprintf("%s/pipeline/%s/%s", apiEndpoint, hf.task, hf.repoID)

	res, err := hf.doRequest(ctx, http.MethodPost, reqURL, req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	body, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}

	if res.StatusCode != http.StatusOK {
		errResp := ErrorResponse{}
		if err := json.Unmarshal(body, &errResp); err != nil {
			return nil, err
		}

		return nil, fmt.Errorf("hugging face error: %s", errResp.Error)
	}

	summarizationResponse := SummarizationResponse{}
	if err := json.Unmarshal(body, &summarizationResponse); err != nil {
		return nil, err
	}

	return &summarizationResponse, nil
}

func (hf *Client) TextGeneration(ctx context.Context, req *TextGenerationRequest) (TextGenerationResponse, error) {
	reqURL := fmt.Sprintf("%s/pipeline/%s/%s", apiEndpoint, hf.task, hf.repoID)

	res, err := hf.doRequest(ctx, http.MethodPost, reqURL, req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	body, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}

	if res.StatusCode != http.StatusOK {
		errResp := ErrorResponse{}
		if err := json.Unmarshal(body, &errResp); err != nil {
			return nil, err
		}

		return nil, fmt.Errorf("hugging face error: %s", errResp.Error)
	}

	textGenerations := TextGenerationResponse{}
	if err := json.Unmarshal(body, &textGenerations); err != nil {
		return nil, err
	}

	return textGenerations, nil
}

func (hf *Client) Text2TextGeneration(ctx context.Context, req *Text2TextGenerationRequest) (Text2TextGenerationResponse, error) {
	reqURL := fmt.Sprintf("%s/pipeline/%s/%s", apiEndpoint, hf.task, hf.repoID)

	res, err := hf.doRequest(ctx, http.MethodPost, reqURL, req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	body, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}

	if res.StatusCode != http.StatusOK {
		errResp := ErrorResponse{}
		if err := json.Unmarshal(body, &errResp); err != nil {
			return nil, err
		}

		return nil, fmt.Errorf("hugging face error: %s", errResp.Error)
	}

	text2TextGenerationResponse := Text2TextGenerationResponse{}
	if err := json.Unmarshal(body, &text2TextGenerationResponse); err != nil {
		return nil, err
	}

	return text2TextGenerationResponse, nil
}

// doRequest marshals the payload to JSON and sends it with the required
// headers, including the Bearer token for authorization.
func (hf *Client) doRequest(ctx context.Context, method string, url string, payload any) (*http.Response, error) {
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	httpReq, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Accept", "application/json")
	httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", hf.apiToken))

	return http.DefaultClient.Do(httpReq)
}
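A minimal usage sketch of the client above. The import path, API token, and model ID are placeholders, not values from this commit:

package main

import (
	"context"
	"fmt"
	"log"

	// Hypothetical import path; substitute this module's actual path.
	"example.com/huggingfacehub"
)

func main() {
	ctx := context.Background()

	// Token and model ID are placeholders.
	hf := huggingfacehub.New("hf_xxx", "facebook/bart-large-cnn", "summarization")

	resp, err := hf.Summarization(ctx, &huggingfacehub.SummarizationRequest{
		Inputs: "The Eiffel Tower is 324 metres tall, about the same height as an 81-storey building.",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.SummaryText)
}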
@@ -0,0 +1,155 @@
package huggingfacehub

type Options struct {
	// (Default: true). There is a cache layer on the inference API to speed up
	// requests we have already seen. Most models can use those results as-is,
	// as models are deterministic (meaning the results will be the same anyway).
	// However, if you use a non-deterministic model, you can set this parameter
	// to prevent the caching mechanism from being used, resulting in a real new query.
	UseCache *bool `json:"use_cache,omitempty"`

	// (Default: false). If the model is not ready, wait for it instead of receiving 503.
	// It limits the number of requests required to get your inference done. It is advised
	// to only set this flag to true after receiving a 503 error, as it will limit hanging
	// in your application to known places.
	WaitForModel *bool `json:"wait_for_model,omitempty"`
}

type SummarizationParameters struct {
	// (Default: None). Integer to define the minimum length in tokens of the output summary.
	MinLength *int `json:"min_length,omitempty"`

	// (Default: None). Integer to define the maximum length in tokens of the output summary.
	MaxLength *int `json:"max_length,omitempty"`

	// (Default: None). Integer to define the top tokens considered within the sample
	// operation to create new text.
	TopK *int `json:"top_k,omitempty"`

	// (Default: None). Float to define the tokens that are within the sample operation of
	// text generation. Tokens are added to the sample, from most probable to least probable,
	// until the sum of the probabilities is greater than top_p.
	TopP *float64 `json:"top_p,omitempty"`

	// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means
	// regular sampling, 0 means top_k=1, and 100.0 gets closer to uniform probability.
	Temperature *float64 `json:"temperature,omitempty"`

	// (Default: None). Float (0.0-100.0). The more a token is used within generation, the
	// more it is penalized so it is not picked in successive generation passes.
	RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`

	// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query
	// should take. The network can cause some overhead, so it is a soft limit.
	MaxTime *float64 `json:"max_time,omitempty"`
}

type SummarizationRequest struct {
	// The string to be summarized.
	Inputs     string                  `json:"inputs"`
	Parameters SummarizationParameters `json:"parameters,omitempty"`
	Options    Options                 `json:"options,omitempty"`
}

type SummarizationResponse struct {
	// The summarized input string.
	SummaryText string `json:"summary_text,omitempty"`
}

type TextGenerationParameters struct {
	// (Default: None). Integer to define the top tokens considered within the sample
	// operation to create new text.
	TopK *int `json:"top_k,omitempty"`

	// (Default: None). Float to define the tokens that are within the sample operation of
	// text generation. Tokens are added to the sample, from most probable to least probable,
	// until the sum of the probabilities is greater than top_p.
	TopP *float64 `json:"top_p,omitempty"`

	// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means
	// regular sampling, 0 means top_k=1, and 100.0 gets closer to uniform probability.
	Temperature *float64 `json:"temperature,omitempty"`

	// (Default: None). Float (0.0-100.0). The more a token is used within generation, the
	// more it is penalized so it is not picked in successive generation passes.
	RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`

	// (Default: None). Int (0-250). The number of new tokens to be generated. This does not
	// include the input length; it is an estimate of the size of the generated text you want.
	// Each new token slows down the request, so look for a balance between response time and
	// the length of the generated text.
	MaxNewTokens *int `json:"max_new_tokens,omitempty"`

	// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query
	// should take. The network can cause some overhead, so it is a soft limit. Use this in
	// combination with max_new_tokens for best results.
	MaxTime *float64 `json:"max_time,omitempty"`

	// (Default: True). Bool. If set to False, the returned results will not contain the
	// original query, making prompting easier.
	ReturnFullText *bool `json:"return_full_text,omitempty"`

	// (Default: 1). Integer. The number of propositions you want returned.
	NumReturnSequences *int `json:"num_return_sequences,omitempty"`
}

type TextGenerationRequest struct {
	// The string to generate from.
	Inputs     string                   `json:"inputs"`
	Parameters TextGenerationParameters `json:"parameters,omitempty"`
	Options    Options                  `json:"options,omitempty"`
}

// A list of generated texts. The length of this list is the value of
// NumReturnSequences in the request.
type TextGenerationResponse []struct {
	GeneratedText string `json:"generated_text,omitempty"`
}

type Text2TextGenerationParameters struct {
	// (Default: None). Integer to define the top tokens considered within the sample
	// operation to create new text.
	TopK *int `json:"top_k,omitempty"`

	// (Default: None). Float to define the tokens that are within the sample operation of
	// text generation. Tokens are added to the sample, from most probable to least probable,
	// until the sum of the probabilities is greater than top_p.
	TopP *float64 `json:"top_p,omitempty"`

	// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means
	// regular sampling, 0 means top_k=1, and 100.0 gets closer to uniform probability.
	Temperature *float64 `json:"temperature,omitempty"`

	// (Default: None). Float (0.0-100.0). The more a token is used within generation, the
	// more it is penalized so it is not picked in successive generation passes.
	RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`

	// (Default: None). Int (0-250). The number of new tokens to be generated. This does not
	// include the input length; it is an estimate of the size of the generated text you want.
	// Each new token slows down the request, so look for a balance between response time and
	// the length of the generated text.
	MaxNewTokens *int `json:"max_new_tokens,omitempty"`

	// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query
	// should take. The network can cause some overhead, so it is a soft limit. Use this in
	// combination with max_new_tokens for best results.
	MaxTime *float64 `json:"max_time,omitempty"`

	// (Default: True). Bool. If set to False, the returned results will not contain the
	// original query, making prompting easier.
	ReturnFullText *bool `json:"return_full_text,omitempty"`

	// (Default: 1). Integer. The number of propositions you want returned.
	NumReturnSequences *int `json:"num_return_sequences,omitempty"`
}

type Text2TextGenerationRequest struct {
	// The string to generate from.
	Inputs     string                        `json:"inputs"`
	Parameters Text2TextGenerationParameters `json:"parameters,omitempty"`
	Options    Options                       `json:"options,omitempty"`
}

type Text2TextGenerationResponse []struct {
	GeneratedText string `json:"generated_text,omitempty"`
}

type ErrorResponse struct {
	Error string `json:"error"`
}
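Because the optional parameters above are pointer fields, callers need to take the address of literal values. A short sketch of building a text-generation request, assuming a generic ptr helper that is not part of this package (the import path is also a placeholder):

package main

import (
	"encoding/json"
	"fmt"

	// Hypothetical import path; substitute this module's actual path.
	"example.com/huggingfacehub"
)

// ptr is a hypothetical helper for taking the address of a literal value;
// it is not defined in this package.
func ptr[T any](v T) *T { return &v }

func buildRequest(prompt string) *huggingfacehub.TextGenerationRequest {
	return &huggingfacehub.TextGenerationRequest{
		Inputs: prompt,
		Parameters: huggingfacehub.TextGenerationParameters{
			Temperature:    ptr(0.7),   // inferred as *float64
			MaxNewTokens:   ptr(100),   // inferred as *int
			ReturnFullText: ptr(false), // omit the prompt from the output
		},
		Options: huggingfacehub.Options{
			// Wait for the model to load instead of receiving a 503.
			WaitForModel: ptr(true),
		},
	}
}

func main() {
	req := buildRequest("Once upon a time")
	b, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(b)) // the JSON body that would be sent to the API
}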