Skip to content

Commit

Permalink
Propagate HubApi configuration (huggingface#62)
Browse files Browse the repository at this point in the history
* propagated HubApi

* review changes
  • Loading branch information
jkrukowski authored Mar 14, 2024
1 parent 95194e6 commit 24605a8
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 17 deletions.
19 changes: 14 additions & 5 deletions Sources/Hub/Hub.swift
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,12 @@ public class LanguageModelConfigurationFromHub {

private var configPromise: Task<Configurations, Error>? = nil

public init(modelName: String) {
public init(
modelName: String,
hubApi: HubApi = .shared
) {
self.configPromise = Task.init {
return try await self.loadConfig(modelName: modelName)
return try await self.loadConfig(modelName: modelName, hubApi: hubApi)
}
}

Expand Down Expand Up @@ -161,8 +164,10 @@ public class LanguageModelConfigurationFromHub {
}
}

func loadConfig(modelName: String, hfToken: String? = nil) async throws -> Configurations {
let hubApi = HubApi(hfToken: hfToken)
func loadConfig(
modelName: String,
hubApi: HubApi = .shared
) async throws -> Configurations {
let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"]
let repo = Hub.Repo(id: modelName)
try await hubApi.snapshot(from: repo, matching: filesToDownload)
Expand All @@ -172,7 +177,11 @@ public class LanguageModelConfigurationFromHub {
let tokenizerConfig = try? hubApi.configuration(from: "tokenizer_config.json", in: repo)
let tokenizerVocab = try hubApi.configuration(from: "tokenizer.json", in: repo)

let configs = Configurations(modelConfig: modelConfig, tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerVocab)
let configs = Configurations(
modelConfig: modelConfig,
tokenizerConfig: tokenizerConfig,
tokenizerData: tokenizerVocab
)
return configs
}

Expand Down
30 changes: 20 additions & 10 deletions Sources/Hub/HubApi.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,29 @@
import Foundation

public struct HubApi {
var downloadBase: URL
var hfToken: String?
var endpoint: String
public let downloadBase: URL
public let hfToken: String?
public let endpoint: String

public typealias RepoType = Hub.RepoType
public typealias Repo = Hub.Repo

public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co") {
if downloadBase == nil {
public init(
downloadBase: URL? = nil,
hfToken: String? = nil,
endpoint: String = "https://huggingface.co"
) {
self.hfToken = hfToken
if let downloadBase {
self.downloadBase = downloadBase
} else {
let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first!
self.downloadBase = documents.appending(component: "huggingface")
} else {
self.downloadBase = downloadBase!
}
self.hfToken = hfToken
self.endpoint = endpoint
}

static let shared = HubApi()
public static let shared = HubApi()
}

/// File retrieval
Expand Down Expand Up @@ -179,7 +183,13 @@ public extension HubApi {
let repoDestination = localRepoLocation(repo)
for filename in filenames {
let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1)
let downloader = HubFileDownloader(repo: repo, repoDestination: repoDestination, relativeFilename: filename, hfToken: hfToken, endpoint: endpoint)
let downloader = HubFileDownloader(
repo: repo,
repoDestination: repoDestination,
relativeFilename: filename,
hfToken: hfToken,
endpoint: endpoint
)
try await downloader.download { fractionDownloaded in
fileProgress.completedUnitCount = Int64(100 * fractionDownloaded)
progressHandler(progress)
Expand Down
7 changes: 5 additions & 2 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,11 @@ extension AutoTokenizer {
return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}

public static func from(pretrained model: String) async throws -> Tokenizer {
let config = LanguageModelConfigurationFromHub(modelName: model)
public static func from(
pretrained model: String,
hubApi: HubApi = .shared
) async throws -> Tokenizer {
let config = LanguageModelConfigurationFromHub(modelName: model, hubApi: hubApi)
guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
let tokenizerData = try await config.tokenizerData

Expand Down

0 comments on commit 24605a8

Please sign in to comment.