Skip to content

Commit

Permalink
Merge branch 'main' into streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
Krivoblotsky committed May 15, 2023
2 parents 8a8aafd + eefb14b commit 4854390
Show file tree
Hide file tree
Showing 20 changed files with 246 additions and 50 deletions.
9 changes: 8 additions & 1 deletion Demo/App/APIProvidedView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import SwiftUI
struct APIProvidedView: View {
@Binding var apiKey: String
@StateObject var chatStore: ChatStore
@StateObject var miscStore: MiscStore
@State var isShowingAPIConfigModal: Bool = true

@Environment(\.idProviderValue) var idProvider
Expand All @@ -28,11 +29,17 @@ struct APIProvidedView: View {
idProvider: idProvider
)
)
self._miscStore = StateObject(
wrappedValue: MiscStore(
openAIClient: OpenAI(apiToken: apiKey.wrappedValue)
)
)
}

var body: some View {
ContentView(
chatStore: chatStore
chatStore: chatStore,
miscStore: miscStore
)
.onChange(of: apiKey) { newApiKey in
chatStore.openAIClient = OpenAI(apiToken: newApiKey)
Expand Down
9 changes: 9 additions & 0 deletions Demo/App/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import SwiftUI

struct ContentView: View {
@ObservedObject var chatStore: ChatStore
@ObservedObject var miscStore: MiscStore
@State private var selectedTab = 0
@Environment(\.idProviderValue) var idProvider

Expand All @@ -37,6 +38,14 @@ struct ContentView: View {
Label("Image", systemImage: "photo")
}
.tag(2)

MiscView(
store: miscStore
)
.tabItem {
Label("Misc", systemImage: "ellipsis")
}
.tag(3)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion Demo/DemoChat/Sources/ChatStore.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// Store.swift
// ChatStore.swift
// DemoChat
//
// Created by Sihao Lu on 3/25/23.
Expand Down
94 changes: 94 additions & 0 deletions Demo/DemoChat/Sources/MiscStore.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
//
// MiscStore.swift
// DemoChat
//
// Created by Aled Samuel on 22/04/2023.
//

import Foundation
import OpenAI

public final class MiscStore: ObservableObject {
public var openAIClient: OpenAIProtocol

@Published var availableModels: [ModelResult] = []

public init(
openAIClient: OpenAIProtocol
) {
self.openAIClient = openAIClient
}

// MARK: Models

@MainActor
func getModels() async {
do {
let response = try await openAIClient.models()
availableModels = response.data
} catch {
// TODO: Better error handling
print(error.localizedDescription)
}
}

// MARK: Moderations

@Published var moderationConversation = Conversation(id: "", messages: [])
@Published var moderationConversationError: Error?

@MainActor
func sendModerationMessage(_ message: Message) async {
moderationConversation.messages.append(message)
await completeModerationChat(message: message)
}

@MainActor
func completeModerationChat(message: Message) async {

moderationConversationError = nil

do {
let response = try await openAIClient.moderations(
query: ModerationsQuery(
input: message.content,
model: .textModerationLatest
)
)

let categoryResults = response.results

let existingMessages = moderationConversation.messages

func circleEmoji(for resultType: Bool) -> String {
resultType ? "🔴" : "🟢"
}

for result in categoryResults {
let content = """
\(circleEmoji(for: result.categories.hate)) Hate
\(circleEmoji(for: result.categories.hateThreatening)) Hate/Threatening
\(circleEmoji(for: result.categories.selfHarm)) Self-harm
\(circleEmoji(for: result.categories.sexual)) Sexual
\(circleEmoji(for: result.categories.sexualMinors)) Sexual/Minors
\(circleEmoji(for: result.categories.violence)) Violence
\(circleEmoji(for: result.categories.violenceGraphic)) Violence/Graphic
"""

let message = Message(
id: response.id,
role: .assistant,
content: content,
createdAt: message.createdAt)

if existingMessages.contains(message) {
continue
}
moderationConversation.messages.append(message)
}

} catch {
moderationConversationError = error
}
}
}
File renamed without changes.
2 changes: 1 addition & 1 deletion Demo/DemoChat/Sources/UI/Environment/IDProvider.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//
// IDProvider.swift
//
// DemoChat
//
// Created by Sihao Lu on 4/6/23.
//
Expand Down
27 changes: 27 additions & 0 deletions Demo/DemoChat/Sources/UI/Misc/ListModelsView.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//
// ListModelsView.swift
// DemoChat
//
// Created by Aled Samuel on 22/04/2023.
//

import SwiftUI

public struct ListModelsView: View {
@ObservedObject var store: MiscStore

public var body: some View {
NavigationStack {
List($store.availableModels) { row in
Text(row.id)
}
.listStyle(.insetGrouped)
.navigationTitle("Models")
}
.onAppear {
Task {
await store.getModels()
}
}
}
}
39 changes: 39 additions & 0 deletions Demo/DemoChat/Sources/UI/Misc/MiscView.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//
// MiscView.swift
// DemoChat
//
// Created by Aled Samuel on 22/04/2023.
//

import SwiftUI

public struct MiscView: View {
@ObservedObject var store: MiscStore

public init(store: MiscStore) {
self.store = store
}

public var body: some View {
NavigationStack {
List {
Section(header: Text("Models")) {
NavigationLink("List Models", destination: ListModelsView(store: store))
NavigationLink("Retrieve Model", destination: RetrieveModelView())
}
Section(header: Text("Moderations")) {
NavigationLink("Moderation Chat", destination: ModerationChatView(store: store))
}
}
.listStyle(.insetGrouped)
.navigationTitle("Misc")
}
}
}

struct RetrieveModelView: View {
var body: some View {
Text("Retrieve Model: TBD")
.font(.largeTitle)
}
}
38 changes: 38 additions & 0 deletions Demo/DemoChat/Sources/UI/ModerationChatView.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//
// ModerationChatView.swift
// DemoChat
//
// Created by Aled Samuel on 26/04/2023.
//

import SwiftUI

public struct ModerationChatView: View {
@ObservedObject var store: MiscStore

@Environment(\.dateProviderValue) var dateProvider
@Environment(\.idProviderValue) var idProvider

public init(store: MiscStore) {
self.store = store
}

public var body: some View {
DetailView(
conversation: store.moderationConversation,
error: store.moderationConversationError,
sendMessage: { message in
Task {
await store.sendModerationMessage(
Message(
id: idProvider(),
role: .user,
content: message,
createdAt: dateProvider()
)
)
}
}
)
}
}
13 changes: 3 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -582,12 +582,6 @@ You can also pass a custom string if you need to use some model, that is not rep

Lists the currently available models.

**Request**

```swift
public struct ModelsQuery: Codable, Equatable { }
```

**Response**

```swift
Expand All @@ -601,12 +595,11 @@ public struct ModelsResult: Codable, Equatable {
**Example**

```swift
let query = ModelsQuery()
openAI.models(query: query) { result in
openAI.models() { result in
//Handle result here
}
//or
let result = try await openAI.models(query: query)
let result = try await openAI.models()
```

#### Retrieve Model
Expand Down Expand Up @@ -736,7 +729,7 @@ func embeddings(query: EmbeddingsQuery) -> AnyPublisher<EmbeddingsResult, Error>
func chats(query: ChatQuery) -> AnyPublisher<ChatResult, Error>
func edits(query: EditsQuery) -> AnyPublisher<EditsResult, Error>
func model(query: ModelQuery) -> AnyPublisher<ModelResult, Error>
func models(query: ModelsQuery) -> AnyPublisher<ModelsResult, Error>
func models() -> AnyPublisher<ModelsResult, Error>
func moderations(query: ModerationsQuery) -> AnyPublisher<ModerationsResult, Error>
func audioTranscriptions(query: AudioTranscriptionQuery) -> AnyPublisher<AudioTranscriptionResult, Error>
func audioTranslations(query: AudioTranslationQuery) -> AnyPublisher<AudioTranslationResult, Error>
Expand Down
14 changes: 9 additions & 5 deletions Sources/OpenAI/OpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,15 @@ final public class OpenAI: OpenAIProtocol {
public convenience init(configuration: Configuration) {
self.init(configuration: configuration, session: URLSession.shared)
}
init(configuration: Configuration, session: URLSessionProtocol = URLSession.shared) {

init(configuration: Configuration, session: URLSessionProtocol) {
self.configuration = configuration
self.session = session
}

public convenience init(configuration: Configuration, session: URLSession = URLSession.shared) {
self.init(configuration: configuration, session: session as URLSessionProtocol)
}

public func completions(query: CompletionsQuery, completion: @escaping (Result<CompletionsResult, Error>) -> Void) {
performRequest(request: JSONRequest<CompletionsResult>(body: query, url: buildURL(path: .completions)), completion: completion)
Expand Down Expand Up @@ -81,11 +85,11 @@ final public class OpenAI: OpenAIProtocol {
}

public func model(query: ModelQuery, completion: @escaping (Result<ModelResult, Error>) -> Void) {
performRequest(request: JSONRequest<ModelResult>(body: query, url: buildURL(path: .models.withPath(query.model))), completion: completion)
performRequest(request: JSONRequest<ModelResult>(url: buildURL(path: .models.withPath(query.model)), method: "GET"), completion: completion)
}

public func models(query: ModelsQuery, completion: @escaping (Result<ModelsResult, Error>) -> Void) {
performRequest(request: JSONRequest<ModelsResult>(body: query, url: buildURL(path: .models)), completion: completion)
public func models(completion: @escaping (Result<ModelsResult, Error>) -> Void) {
performRequest(request: JSONRequest<ModelsResult>(url: buildURL(path: .models), method: "GET"), completion: completion)
}

public func moderations(query: ModerationsQuery, completion: @escaping (Result<ModerationsResult, Error>) -> Void) {
Expand Down
8 changes: 5 additions & 3 deletions Sources/OpenAI/Private/JSONRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import FoundationNetworking

final class JSONRequest<ResultType> {

let body: Codable
let body: Codable?
let url: URL
let method: String

init(body: Codable, url: URL, method: String = "POST") {
init(body: Codable? = nil, url: URL, method: String = "POST") {
self.body = body
self.url = url
self.method = method
Expand All @@ -33,7 +33,9 @@ extension JSONRequest: URLRequestBuildable {
request.setValue(organizationIdentifier, forHTTPHeaderField: "OpenAI-Organization")
}
request.httpMethod = method
request.httpBody = try JSONEncoder().encode(body)
if let body = body {
request.httpBody = try JSONEncoder().encode(body)
}
return request
}
}
2 changes: 2 additions & 0 deletions Sources/OpenAI/Public/Models/Models/ModelResult.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@ public struct ModelResult: Codable, Equatable {
case ownedBy = "owned_by"
}
}

extension ModelResult: Identifiable {}
2 changes: 1 addition & 1 deletion Sources/OpenAI/Public/Models/Models/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public extension Model {

/// Almost as capable as the latest model, but slightly older.
static let textModerationStable = "text-moderation-stable"
/// Most capable moderation model. Accuracy will be slighlty higher than the stable model.
/// Most capable moderation model. Accuracy will be slightly higher than the stable model.
static let textModerationLatest = "text-moderation-latest"
static let moderation = "text-moderation-001"
}
12 changes: 0 additions & 12 deletions Sources/OpenAI/Public/Models/Models/ModelsQuery.swift

This file was deleted.

Loading

0 comments on commit 4854390

Please sign in to comment.