diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index c5649ce..8789ebd 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -112,9 +112,12 @@ public class LanguageModelConfigurationFromHub { private var configPromise: Task? = nil - public init(modelName: String) { + public init( + modelName: String, + hubApi: HubApi = .shared + ) { self.configPromise = Task.init { - return try await self.loadConfig(modelName: modelName) + return try await self.loadConfig(modelName: modelName, hubApi: hubApi) } } @@ -161,8 +164,10 @@ public class LanguageModelConfigurationFromHub { } } - func loadConfig(modelName: String, hfToken: String? = nil) async throws -> Configurations { - let hubApi = HubApi(hfToken: hfToken) + func loadConfig( + modelName: String, + hubApi: HubApi = .shared + ) async throws -> Configurations { let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"] let repo = Hub.Repo(id: modelName) try await hubApi.snapshot(from: repo, matching: filesToDownload) @@ -172,7 +177,11 @@ public class LanguageModelConfigurationFromHub { let tokenizerConfig = try? hubApi.configuration(from: "tokenizer_config.json", in: repo) let tokenizerVocab = try hubApi.configuration(from: "tokenizer.json", in: repo) - let configs = Configurations(modelConfig: modelConfig, tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerVocab) + let configs = Configurations( + modelConfig: modelConfig, + tokenizerConfig: tokenizerConfig, + tokenizerData: tokenizerVocab + ) return configs } diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index da51cc8..edce7ba 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -8,25 +8,29 @@ import Foundation public struct HubApi { - var downloadBase: URL - var hfToken: String? - var endpoint: String + public let downloadBase: URL + public let hfToken: String? + public let endpoint: String public typealias RepoType = Hub.RepoType public typealias Repo = Hub.Repo - public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co") { - if downloadBase == nil { + public init( + downloadBase: URL? = nil, + hfToken: String? = nil, + endpoint: String = "https://huggingface.co" + ) { + self.hfToken = hfToken + if let downloadBase { + self.downloadBase = downloadBase + } else { let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first! self.downloadBase = documents.appending(component: "huggingface") - } else { - self.downloadBase = downloadBase! } - self.hfToken = hfToken self.endpoint = endpoint } - static let shared = HubApi() + public static let shared = HubApi() } /// File retrieval @@ -179,7 +183,13 @@ public extension HubApi { let repoDestination = localRepoLocation(repo) for filename in filenames { let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1) - let downloader = HubFileDownloader(repo: repo, repoDestination: repoDestination, relativeFilename: filename, hfToken: hfToken, endpoint: endpoint) + let downloader = HubFileDownloader( + repo: repo, + repoDestination: repoDestination, + relativeFilename: filename, + hfToken: hfToken, + endpoint: endpoint + ) try await downloader.download { fractionDownloaded in fileProgress.completedUnitCount = Int64(100 * fractionDownloaded) progressHandler(progress) diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift index 300ed2a..093c212 100644 --- a/Sources/Tokenizers/Tokenizer.swift +++ b/Sources/Tokenizers/Tokenizer.swift @@ -234,8 +234,11 @@ extension AutoTokenizer { return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) } - public static func from(pretrained model: String) async throws -> Tokenizer { - let config = LanguageModelConfigurationFromHub(modelName: model) + public static func from( + pretrained model: String, + hubApi: HubApi = .shared + ) async throws -> Tokenizer { + let config = LanguageModelConfigurationFromHub(modelName: model, hubApi: hubApi) guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig } let tokenizerData = try await config.tokenizerData