From 1e63a556a0056f814cbca7e8fa0594991ac7a093 Mon Sep 17 00:00:00 2001
From: Andrew Heard
Date: Tue, 8 Oct 2024 20:47:21 -0400
Subject: [PATCH] [Vertex AI] Use `struct` instead of `enum` for
 `HarmProbability` (#13854)

---
 FirebaseVertexAI/CHANGELOG.md                 |  7 +-
 .../ChatSample/Views/ErrorDetailsView.swift   |  5 +-
 FirebaseVertexAI/Sources/Safety.swift         | 79 +++++++++++++------
 .../Tests/Unit/GenerativeModelTests.swift     |  7 +-
 4 files changed, 67 insertions(+), 31 deletions(-)

diff --git a/FirebaseVertexAI/CHANGELOG.md b/FirebaseVertexAI/CHANGELOG.md
index 3aa08503f17..5e90031a137 100644
--- a/FirebaseVertexAI/CHANGELOG.md
+++ b/FirebaseVertexAI/CHANGELOG.md
@@ -36,9 +36,10 @@
   as input. (#13767)
 - [changed] **Breaking Change**: All initializers for `ModelContent` now
   require the label `parts: `. (#13832)
-- [changed] **Breaking Change**: `HarmCategory` is now a struct instead of an
-  enum type and the `unknown` case has been removed; in a `switch` statement,
-  use the `default:` case to cover unknown or unhandled categories. (#13728)
+- [changed] **Breaking Change**: `HarmCategory` and `HarmProbability` are now
+  structs instead of enum types and the `unknown` cases have been removed; in a
+  `switch` statement, use the `default:` case to cover unknown or unhandled
+  categories or probabilities. (#13728, #13854)
 - [changed] The default request timeout is now 180 seconds instead of the
   platform-default value of 60 seconds for a `URLRequest`; this timeout may
   still be customized in `RequestOptions`. (#13722)
diff --git a/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift b/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift
index 11ba86bb1b9..279f02b81fc 100644
--- a/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift
+++ b/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift
@@ -25,8 +25,7 @@ private extension HarmCategory {
     case .hateSpeech: "Hate speech"
     case .sexuallyExplicit: "Sexually explicit"
     case .civicIntegrity: "Civic integrity"
-    default:
-      "Unknown HarmCategory: \(rawValue)"
+    default: "Unknown HarmCategory: \(rawValue)"
     }
   }
 }
@@ -39,7 +38,7 @@ private extension SafetyRating.HarmProbability {
     case .low: "Low"
     case .medium: "Medium"
     case .negligible: "Negligible"
-    case .unknown: "Unknown"
+    default: "Unknown HarmProbability: \(rawValue)"
     }
   }
 }
diff --git a/FirebaseVertexAI/Sources/Safety.swift b/FirebaseVertexAI/Sources/Safety.swift
index 3f6ce4658c1..280771d0074 100644
--- a/FirebaseVertexAI/Sources/Safety.swift
+++ b/FirebaseVertexAI/Sources/Safety.swift
@@ -38,24 +38,66 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
     self.probability = probability
   }
 
-  /// The probability that a given model output falls under a harmful content category. This does
-  /// not indicate the severity of harm for a piece of content.
-  public enum HarmProbability: String, Sendable {
-    /// Unknown. A new server value that isn't recognized by the SDK.
-    case unknown = "UNKNOWN"
+  /// The probability that a given model output falls under a harmful content category.
+  ///
+  /// > Note: This does not indicate the severity of harm for a piece of content.
+  public struct HarmProbability: Sendable, Equatable, Hashable {
+    enum Kind: String {
+      case negligible = "NEGLIGIBLE"
+      case low = "LOW"
+      case medium = "MEDIUM"
+      case high = "HIGH"
+    }
 
-    /// The probability is zero or close to zero. For benign content, the probability across all
-    /// categories will be this value.
-    case negligible = "NEGLIGIBLE"
+    /// The probability is zero or close to zero.
+    ///
+    /// For benign content, the probability across all categories will be this value.
+    public static var negligible: HarmProbability {
+      return self.init(kind: .negligible)
+    }
 
     /// The probability is small but non-zero.
-    case low = "LOW"
+    public static var low: HarmProbability {
+      return self.init(kind: .low)
+    }
 
     /// The probability is moderate.
-    case medium = "MEDIUM"
+    public static var medium: HarmProbability {
+      return self.init(kind: .medium)
+    }
+
+    /// The probability is high.
+    ///
+    /// The content described is very likely harmful.
+    public static var high: HarmProbability {
+      return self.init(kind: .high)
+    }
+
+    /// Returns the raw string representation of the `HarmProbability` value.
+    ///
+    /// > Note: This value directly corresponds to the values in the [REST
+    /// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#SafetyRating).
+    public let rawValue: String
 
-    /// The probability is high. The content described is very likely harmful.
-    case high = "HIGH"
+    init(kind: Kind) {
+      rawValue = kind.rawValue
+    }
+
+    init(rawValue: String) {
+      if Kind(rawValue: rawValue) == nil {
+        VertexLog.error(
+          code: .generateContentResponseUnrecognizedHarmProbability,
+          """
+          Unrecognized HarmProbability with value "\(rawValue)":
+          - Check for updates to the SDK as support for "\(rawValue)" may have been added; see \
+          release notes at https://firebase.google.com/support/release-notes/ios
+          - Search for "\(rawValue)" in the Firebase Apple SDK Issue Tracker at \
+          https://github.com/firebase/firebase-ios-sdk/issues and file a Bug Report if none found
+          """
+        )
+      }
+      self.rawValue = rawValue
+    }
   }
 }
@@ -163,17 +205,8 @@ public struct HarmCategory: Sendable, Equatable, Hashable {
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 extension SafetyRating.HarmProbability: Decodable {
   public init(from decoder: Decoder) throws {
-    let value = try decoder.singleValueContainer().decode(String.self)
-    guard let decodedProbability = SafetyRating.HarmProbability(rawValue: value) else {
-      VertexLog.error(
-        code: .generateContentResponseUnrecognizedHarmProbability,
-        "Unrecognized HarmProbability with value \"\(value)\"."
-      )
-      self = .unknown
-      return
-    }
-
-    self = decodedProbability
+    let rawValue = try decoder.singleValueContainer().decode(String.self)
+    self = SafetyRating.HarmProbability(rawValue: rawValue)
   }
 }
diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
index 254f81e96fb..21076991003 100644
--- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
+++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
@@ -162,7 +162,10 @@ final class GenerativeModelTests: XCTestCase {
   func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
     let expectedSafetyRatings = [
       SafetyRating(category: .harassment, probability: .medium),
-      SafetyRating(category: .dangerousContent, probability: .unknown),
+      SafetyRating(
+        category: .dangerousContent,
+        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY")
+      ),
       SafetyRating(category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"), probability: .high),
     ]
     MockURLProtocol
@@ -974,7 +977,7 @@ final class GenerativeModelTests: XCTestCase {
       )
       let unknownSafetyRating = SafetyRating(
         category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
-        probability: .unknown
+        probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM")
       )
       var foundUnknownSafetyRating = false
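Migration note: because `HarmProbability` is now a struct rather than an enum, exhaustive `switch` statements over its cases no longer compile. Since the type is `Equatable`, `case` patterns against the static values (`.negligible`, `.low`, `.medium`, `.high`) still match, but a `default:` branch is required to cover unrecognized server values. A minimal sketch of the updated call-site pattern, mirroring the `ErrorDetailsView` change above; the `displayValue` property name is illustrative, not part of this patch:

import FirebaseVertexAI

private extension SafetyRating.HarmProbability {
  // Illustrative helper: maps a probability to a display string.
  var displayValue: String {
    switch self {
    case .high: "High"
    case .medium: "Medium"
    case .low: "Low"
    case .negligible: "Negligible"
    // `.unknown` no longer exists; new or unrecognized server values
    // carry their raw string and are handled here.
    default: "Unknown HarmProbability: \(rawValue)"
    }
  }
}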