Skip to content

Commit

Permalink
Fix schema generation for floating point types (#5788)
Browse files Browse the repository at this point in the history
* Fix schema generation for floating point types

* Ensure the correct type keyword is being replaced.
  • Loading branch information
eiriktsarpalis authored Jan 10, 2025
1 parent e4ebc9e commit b86638d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
#if !NET9_0_OR_GREATER
using System.Diagnostics.CodeAnalysis;
#endif
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
Expand All @@ -21,6 +18,7 @@
#pragma warning disable S107 // Methods should not have too many parameters
#pragma warning disable S1075 // URIs should not be hardcoded
#pragma warning disable SA1118 // Parameter should not span multiple lines
#pragma warning disable S109 // Magic numbers should not be used

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -304,14 +302,14 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js

// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
// schemas with "type": [...], and only understand "type" being a single value.
// STJ represents .NET integer types as ["string", "integer"], which will then lead to an error.
if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema))
// In certain configurations STJ represents .NET numeric types as ["string", "number"], which will then lead to an error.
if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema, out string? numericType))
{
// We don't want to emit any array for "type". In this case we know it contains "integer"
// We don't want to emit any array for "type". In this case we know it contains "integer" or "number",
// so reduce the type to that alone, assuming it's the most specific type.
// This makes schemas for Int32 (etc) work with Ollama.
JsonObject obj = ConvertSchemaToObject(ref schema);
obj[TypePropertyName] = "integer";
obj[TypePropertyName] = numericType;
_ = obj.Remove(PatternPropertyName);
}
}
Expand Down Expand Up @@ -380,21 +378,32 @@ static JsonObject ConvertSchemaToObject(ref JsonNode schema)
}
}

private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateContext ctx, JsonObject schema)
private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateContext ctx, JsonObject schema, [NotNullWhen(true)] out string? numericType)
{
if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray typeArray)
numericType = null;

if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray { Count: 2 } typeArray)
{
int count = 0;
bool allowString = false;

foreach (JsonNode? entry in typeArray)
{
if (entry?.GetValueKind() is JsonValueKind.String &&
entry.GetValue<string>() is "integer" or "string")
entry.GetValue<string>() is string type)
{
count++;
switch (type)
{
case "integer" or "number":
numericType = type;
break;
case "string":
allowString = true;
break;
}
}
}

return count == typeArray.Count;
return allowString && numericType is not null;
}

return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,18 +227,22 @@ public static void ResolveParameterJsonSchema_ReturnsExpectedValue()
[Fact]
public static void CreateParameterJsonSchema_TreatsIntegralTypesAsInteger_EvenWithAllowReadingFromString()
{
JsonElement expected = JsonDocument.Parse("""
{
"type": "integer"
}
""").RootElement;

JsonSerializerOptions options = new(JsonSerializerOptions.Default) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
AIFunction func = AIFunctionFactory.Create((int a, int? b, long c, short d) => { }, serializerOptions: options);
AIFunction func = AIFunctionFactory.Create((int a, int? b, long c, short d, float e, double f, decimal g) => { }, serializerOptions: options);

AIFunctionMetadata metadata = func.Metadata;
foreach (var param in metadata.Parameters)
{
string numericType = Type.GetTypeCode(param.ParameterType) is TypeCode.Double or TypeCode.Single or TypeCode.Decimal
? "number"
: "integer";

JsonElement expected = JsonDocument.Parse($$"""
{
"type": "{{numericType}}"
}
""").RootElement;

JsonElement actualSchema = Assert.IsType<JsonElement>(param.Schema);
Assert.True(JsonElement.DeepEquals(expected, actualSchema));
}
Expand Down

0 comments on commit b86638d

Please sign in to comment.