Skip to content

Commit

Permalink
Add an extension method for registering custom AIContent types (#5789)
Browse files Browse the repository at this point in the history
* Add an extension method for registering custom AIContent types

* Fix method chaining.
  • Loading branch information
eiriktsarpalis authored Jan 10, 2025
1 parent b86638d commit 309f2b7
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using Microsoft.Shared.Diagnostics;

#pragma warning disable S1121 // Assignments should not be made from within sub-expressions

namespace Microsoft.Extensions.AI;

public static partial class AIJsonUtilities
{
/// <summary>
/// Adds a custom content type to the polymorphic configuration for <see cref="AIContent"/>.
/// </summary>
/// <typeparam name="TContent">The custom content type to configure.</typeparam>
/// <param name="options">The options instance to configure.</param>
/// <param name="typeDiscriminatorId">The type discriminator id for the content type.</param>
/// <exception cref="ArgumentNullException"><paramref name="options"/> or <paramref name="typeDiscriminatorId"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentException"><typeparamref name="TContent"/> is a built-in content type.</exception>
/// <exception cref="InvalidOperationException"><paramref name="options"/> is a read-only instance.</exception>
public static void AddAIContentType<TContent>(this JsonSerializerOptions options, string typeDiscriminatorId)
where TContent : AIContent
{
_ = Throw.IfNull(options);
_ = Throw.IfNull(typeDiscriminatorId);

AddAIContentTypeCore(options, typeof(TContent), typeDiscriminatorId);
}

/// <summary>
/// Adds a custom content type to the polymorphic configuration for <see cref="AIContent"/>.
/// </summary>
/// <param name="options">The options instance to configure.</param>
/// <param name="contentType">The custom content type to configure.</param>
/// <param name="typeDiscriminatorId">The type discriminator id for the content type.</param>
/// <exception cref="ArgumentNullException"><paramref name="options"/>, <paramref name="contentType"/>, or <paramref name="typeDiscriminatorId"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentException"><paramref name="contentType"/> is a built-in content type or does not derived from <see cref="AIContent"/>.</exception>
/// <exception cref="InvalidOperationException"><paramref name="options"/> is a read-only instance.</exception>
public static void AddAIContentType(this JsonSerializerOptions options, Type contentType, string typeDiscriminatorId)
{
_ = Throw.IfNull(options);
_ = Throw.IfNull(contentType);
_ = Throw.IfNull(typeDiscriminatorId);

if (!typeof(AIContent).IsAssignableFrom(contentType))
{
Throw.ArgumentException(nameof(contentType), "The content type must derive from AIContent.");
}

AddAIContentTypeCore(options, contentType, typeDiscriminatorId);
}

private static void AddAIContentTypeCore(JsonSerializerOptions options, Type contentType, string typeDiscriminatorId)
{
if (contentType.Assembly == typeof(AIContent).Assembly)
{
Throw.ArgumentException(nameof(contentType), "Cannot register built-in AI content types.");
}

IJsonTypeInfoResolver resolver = options.TypeInfoResolver ?? DefaultOptions.TypeInfoResolver!;
options.TypeInfoResolver = resolver.WithAddedModifier(typeInfo =>
{
if (typeInfo.Type == typeof(AIContent))
{
(typeInfo.PolymorphismOptions ??= new()).DerivedTypes.Add(new(contentType, typeDiscriminatorId));
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,68 @@ public static void CreateJsonSchema_ValidateWithTestData(ITestData testData)
JsonNode? serializedValue = JsonSerializer.SerializeToNode(testData.Value, testData.Type, options);
SchemaTestHelpers.AssertDocumentMatchesSchema(schemaAsNode, serializedValue);
}

[Fact]
public static void AddAIContentType_DerivedAIContent()
{
JsonSerializerOptions options = new();
options.AddAIContentType<DerivedAIContent>("derivativeContent");

AIContent c = new DerivedAIContent { DerivedValue = 42 };
string json = JsonSerializer.Serialize(c, options);
Assert.Equal("""{"$type":"derivativeContent","DerivedValue":42,"AdditionalProperties":null}""", json);

AIContent? deserialized = JsonSerializer.Deserialize<AIContent>(json, options);
Assert.IsType<DerivedAIContent>(deserialized);
}

[Fact]
public static void AddAIContentType_ReadOnlyJsonSerializerOptions_ThrowsInvalidOperationException()
{
Assert.Throws<InvalidOperationException>(() => AIJsonUtilities.DefaultOptions.AddAIContentType<DerivedAIContent>("derivativeContent"));
}

[Fact]
public static void AddAIContentType_NonAIContent_ThrowsArgumentException()
{
JsonSerializerOptions options = new();
Assert.Throws<ArgumentException>(() => options.AddAIContentType(typeof(int), "discriminator"));
Assert.Throws<ArgumentException>(() => options.AddAIContentType(typeof(object), "discriminator"));
Assert.Throws<ArgumentException>(() => options.AddAIContentType(typeof(ChatMessage), "discriminator"));
}

[Fact]
public static void AddAIContentType_BuiltInAIContent_ThrowsArgumentException()
{
JsonSerializerOptions options = new();
Assert.Throws<ArgumentException>(() => options.AddAIContentType<AIContent>("discriminator"));
Assert.Throws<ArgumentException>(() => options.AddAIContentType<TextContent>("discriminator"));
}

[Fact]
public static void AddAIContentType_ConflictingIdentifier_ThrowsInvalidOperationException()
{
JsonSerializerOptions options = new();
options.AddAIContentType<DerivedAIContent>("text");
options.AddAIContentType<DerivedAIContent>("audio");

AIContent c = new DerivedAIContent();
Assert.Throws<InvalidOperationException>(() => JsonSerializer.Serialize(c, options));
}

[Fact]
public static void AddAIContentType_NullArguments_ThrowsArgumentNullException()
{
JsonSerializerOptions options = new();
Assert.Throws<ArgumentNullException>(() => ((JsonSerializerOptions)null!).AddAIContentType<DerivedAIContent>("discriminator"));
Assert.Throws<ArgumentNullException>(() => ((JsonSerializerOptions)null!).AddAIContentType(typeof(DerivedAIContent), "discriminator"));
Assert.Throws<ArgumentNullException>(() => options.AddAIContentType<DerivedAIContent>(null!));
Assert.Throws<ArgumentNullException>(() => options.AddAIContentType(typeof(DerivedAIContent), null!));
Assert.Throws<ArgumentNullException>(() => options.AddAIContentType(null!, "discriminator"));
}

private class DerivedAIContent : AIContent
{
public int DerivedValue { get; set; }
}
}

0 comments on commit 309f2b7

Please sign in to comment.