Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve EmbeddingGeneratorExtensions #5551

Merged
merged 2 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,59 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;

#pragma warning disable S2302 // "nameof" should be used

namespace Microsoft.Extensions.AI;

/// <summary>Provides a collection of static methods for extending <see cref="IEmbeddingGenerator{TValue,TEmbedding}"/> instances.</summary>
/// <summary>Provides a collection of static methods for extending <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> instances.</summary>
public static class EmbeddingGeneratorExtensions
{
/// <summary>Generates an embedding from the specified <paramref name="value"/>.</summary>
/// <typeparam name="TValue">The type from which embeddings will be generated.</typeparam>
/// <summary>Generates an embedding vector from the specified <paramref name="value"/>.</summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
/// <param name="generator">The embedding generator.</param>
/// <param name="value">A value from which an embedding will be generated.</param>
/// <param name="options">The embedding generation options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The generated embedding for the specified <paramref name="value"/>.</returns>
public static async Task<TEmbedding> GenerateAsync<TValue, TEmbedding>(
this IEmbeddingGenerator<TValue, TEmbedding> generator,
TValue value,
/// <remarks>
/// This operation is equivalent to using <see cref="GenerateEmbeddingAsync"/> and returning the
/// resulting <see cref="Embedding{T}"/>'s <see cref="Embedding{T}.Vector"/> property.
/// </remarks>
public static async Task<ReadOnlyMemory<TEmbedding>> GenerateEmbeddingVectorAsync<TInput, TEmbedding>(
    this IEmbeddingGenerator<TInput, Embedding<TEmbedding>> generator,
    TInput value,
    EmbeddingGenerationOptions? options = null,
    CancellationToken cancellationToken = default)
{
    // Delegate to GenerateEmbeddingAsync for the single-value request, then unwrap
    // the resulting Embedding<T> so that callers receive only its Vector.
    return (await GenerateEmbeddingAsync(generator, value, options, cancellationToken).ConfigureAwait(false)).Vector;
}

/// <summary>Generates an embedding from the specified <paramref name="value"/>.</summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The type of embedding to generate.</typeparam>
/// <param name="generator">The embedding generator.</param>
/// <param name="value">A value from which an embedding will be generated.</param>
/// <param name="options">The embedding generation options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>
/// The generated embedding for the specified <paramref name="value"/>.
/// </returns>
/// <remarks>
/// This operation is equivalent to using <see cref="IEmbeddingGenerator{TInput, TEmbedding}.GenerateAsync"/> with a
/// collection composed of the single <paramref name="value"/> and then returning the first embedding element from the
/// resulting <see cref="GeneratedEmbeddings{TEmbedding}"/> collection.
/// </remarks>
public static async Task<TEmbedding> GenerateEmbeddingAsync<TInput, TEmbedding>(
this IEmbeddingGenerator<TInput, TEmbedding> generator,
TInput value,
EmbeddingGenerationOptions? options = null,
CancellationToken cancellationToken = default)
where TEmbedding : Embedding
Expand All @@ -30,11 +63,61 @@ public static async Task<TEmbedding> GenerateAsync<TValue, TEmbedding>(
_ = Throw.IfNull(value);

var embeddings = await generator.GenerateAsync([value], options, cancellationToken).ConfigureAwait(false);

if (embeddings is null)
{
throw new InvalidOperationException("Embedding generator returned a null collection of embeddings.");
}

if (embeddings.Count != 1)
{
throw new InvalidOperationException("Expected exactly one embedding to be generated.");
throw new InvalidOperationException($"Expected the number of embeddings ({embeddings.Count}) to match the number of inputs (1).");
}

return embeddings[0] ?? throw new InvalidOperationException("Embedding generator generated a null embedding.");
}

/// <summary>
/// Generates embeddings for each of the supplied <paramref name="values"/> and produces a list that pairs
/// each input value with its resulting embedding.
/// </summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The type of embedding to generate.</typeparam>
/// <param name="generator">The embedding generator.</param>
/// <param name="values">The collection of values for which to generate embeddings.</param>
/// <param name="options">The embedding generation options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>An array containing tuples of the input values and the associated generated embeddings.</returns>
/// <exception cref="ArgumentNullException"><paramref name="generator"/> or <paramref name="values"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException">
/// The generator returned a null collection, or the number of embeddings it returned does not match the number of inputs.
/// </exception>
/// <remarks>
/// If <paramref name="values"/> is not already an <see cref="IList{T}"/>, it is copied into a list once, and that
/// same list is both sent to the generator and used for pairing. An empty input yields an empty array without
/// invoking the generator.
/// </remarks>
public static async Task<(TInput Value, TEmbedding Embedding)[]> GenerateAndZipAsync<TInput, TEmbedding>(
    this IEmbeddingGenerator<TInput, TEmbedding> generator,
    IEnumerable<TInput> values,
    EmbeddingGenerationOptions? options = null,
    CancellationToken cancellationToken = default)
    where TEmbedding : Embedding
{
    _ = Throw.IfNull(generator);
    _ = Throw.IfNull(values);

    // Materialize once so the count and the indexer are available for zipping.
    IList<TInput> inputs = values as IList<TInput> ?? values.ToList();
    int inputsCount = inputs.Count;

    if (inputsCount == 0)
    {
        return Array.Empty<(TInput, TEmbedding)>();
    }

    // Pass the materialized list rather than the original enumerable: re-enumerating a lazy
    // sequence could produce different items (or a different count) than the ones captured above.
    var embeddings = await generator.GenerateAsync(inputs, options, cancellationToken).ConfigureAwait(false);

    if (embeddings is null)
    {
        throw new InvalidOperationException("Embedding generator returned a null collection of embeddings.");
    }

    if (embeddings.Count != inputsCount)
    {
        throw new InvalidOperationException($"Expected the number of embeddings ({embeddings.Count}) to match the number of inputs ({inputsCount}).");
    }

    var results = new (TInput, TEmbedding)[inputsCount];
    for (int i = 0; i < results.Length; i++)
    {
        results[i] = (inputs[i], embeddings[i]);
    }

    return results;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Linq;
using System.Threading.Tasks;
using Xunit;

Expand All @@ -12,7 +13,9 @@ public class EmbeddingGeneratorExtensionsTests
// Verifies that each single-call extension method rejects a null generator by throwing
// ArgumentNullException whose parameter name is "generator".
[Fact]
public async Task GenerateAsync_InvalidArgs_ThrowsAsync()
{
// The null generator is cast to the concrete test type so the extension methods bind to it.
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateEmbeddingAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateEmbeddingVectorAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateAndZipAsync(["hello"]));
}

[Fact]
Expand All @@ -26,6 +29,35 @@ public async Task GenerateAsync_ReturnsSingleEmbeddingAsync()
Task.FromResult<GeneratedEmbeddings<Embedding<float>>>([result])
};

Assert.Same(result, await service.GenerateAsync("hello"));
Assert.Same(result, await service.GenerateEmbeddingAsync("hello"));
Assert.Equal(result.Vector, await service.GenerateEmbeddingVectorAsync("hello"));
}

[Theory]
[InlineData(0)]
[InlineData(1)]
[InlineData(10)]
public async Task GenerateAndZipEmbeddingsAsync_ReturnsExpectedList(int count)
{
    // Arrange: build `count` inputs and, for each, a distinct four-element embedding
    // whose values start at the input's index.
    var inputs = new string[count];
    var embeddings = new Embedding<float>[count];
    for (int index = 0; index < count; index++)
    {
        inputs[index] = $"hello {index}";
        embeddings[index] = new Embedding<float>(new float[] { index, index + 1, index + 2, index + 3 });
    }

    // The fake generator ignores its arguments and always returns the prepared embeddings.
    using TestEmbeddingGenerator service = new()
    {
        GenerateAsyncCallback = (values, options, cancellationToken) =>
            Task.FromResult<GeneratedEmbeddings<Embedding<float>>>(new(embeddings))
    };

    // Act
    var zipped = await service.GenerateAndZipAsync(inputs);

    // Assert: one pair per input, in order, carrying the very same embedding instances.
    Assert.NotNull(zipped);
    Assert.Equal(count, zipped.Length);
    for (int index = 0; index < count; index++)
    {
        var (value, embedding) = zipped[index];
        Assert.Equal(inputs[index], value);
        Assert.Same(embeddings[index], embedding);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ public virtual async Task Caching_SameOutputsForSameInput()
.Use(CreateEmbeddingGenerator()!);

string input = "Red, White, and Blue";
var embedding1 = await generator.GenerateAsync(input);
var embedding2 = await generator.GenerateAsync(input);
var embedding3 = await generator.GenerateAsync(input + "... and Green");
var embedding4 = await generator.GenerateAsync(input);
var embedding1 = await generator.GenerateEmbeddingAsync(input);
var embedding2 = await generator.GenerateEmbeddingAsync(input);
var embedding3 = await generator.GenerateEmbeddingAsync(input + "... and Green");
var embedding4 = await generator.GenerateEmbeddingAsync(input);

var callCounter = generator.GetService<CallCountingEmbeddingGenerator>();
Assert.NotNull(callCounter);
Expand All @@ -114,7 +114,7 @@ public virtual async Task OpenTelemetry_CanEmitTracesAndMetrics()
.UseOpenTelemetry(sourceName: sourceName)
.Use(CreateEmbeddingGenerator()!);

_ = await embeddingGenerator.GenerateAsync("Hello, world!");
_ = await embeddingGenerator.GenerateEmbeddingAsync("Hello, world!");

Assert.Single(activities);
var activity = activities.Single();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ public async Task CachesSuccessResultsAsync()
};

// Make the initial request and do a quick sanity check
var result1 = await outer.GenerateAsync("abc");
var result1 = await outer.GenerateEmbeddingAsync("abc");
AssertEmbeddingsEqual(_expectedEmbedding, result1);
Assert.Equal(1, innerCallCount);

// Act
var result2 = await outer.GenerateAsync("abc");
var result2 = await outer.GenerateEmbeddingAsync("abc");

// Assert
Assert.Equal(1, innerCallCount);
Expand Down Expand Up @@ -134,8 +134,8 @@ public async Task AllowsConcurrentCallsAsync()
};

// Act 1: Concurrent calls before resolution are passed into the inner client
var result1 = outer.GenerateAsync("abc");
var result2 = outer.GenerateAsync("abc");
var result1 = outer.GenerateEmbeddingAsync("abc");
var result2 = outer.GenerateEmbeddingAsync("abc");

// Assert 1
Assert.Equal(2, innerCallCount);
Expand All @@ -146,7 +146,7 @@ public async Task AllowsConcurrentCallsAsync()
AssertEmbeddingsEqual(_expectedEmbedding, await result2);

// Act 2: Subsequent calls after completion are resolved from the cache
var result3 = await outer.GenerateAsync("abc");
var result3 = await outer.GenerateEmbeddingAsync("abc");
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, await result1);
}
Expand All @@ -169,12 +169,12 @@ public async Task DoesNotCacheExceptionResultsAsync()
JsonSerializerOptions = TestJsonSerializerContext.Default.Options,
};

var ex1 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateAsync("abc"));
var ex1 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateEmbeddingAsync("abc"));
Assert.Equal("some failure", ex1.Message);
Assert.Equal(1, innerCallCount);

// Act
var ex2 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateAsync("abc"));
var ex2 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateEmbeddingAsync("abc"));

// Assert
Assert.NotSame(ex1, ex2);
Expand Down Expand Up @@ -207,15 +207,15 @@ public async Task DoesNotCacheCanceledResultsAsync()
};

// First call gets cancelled
var result1 = outer.GenerateAsync("abc");
var result1 = outer.GenerateEmbeddingAsync("abc");
Assert.False(result1.IsCompleted);
Assert.Equal(1, innerCallCount);
resolutionTcs.SetCanceled();
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => result1);
Assert.True(result1.IsCanceled);

// Act/Assert: Second call can succeed
var result2 = await outer.GenerateAsync("abc");
var result2 = await outer.GenerateEmbeddingAsync("abc");
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result2);
}
Expand All @@ -241,11 +241,11 @@ public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
};

// Act: Call with two different options
var result1 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result2 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 2" }
});
Expand Down Expand Up @@ -277,11 +277,11 @@ public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync()
};

// Act: Call with two different options
var result1 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result2 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 2" }
});
Expand Down Expand Up @@ -315,7 +315,7 @@ public async Task CanResolveIDistributedCacheFromDI()

// Act: Make a request that should populate the cache
Assert.Empty(_storage.Keys);
var result = await outer.GenerateAsync("abc");
var result = await outer.GenerateEmbeddingAsync("abc");

// Assert
Assert.NotNull(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public async Task CompleteAsync_LogsStartAndCompletion(LogLevel level)
.UseLogging()
.Use(innerGenerator);

await generator.GenerateAsync("Blue whale");
await generator.GenerateEmbeddingAsync("Blue whale");

if (level is LogLevel.Trace)
{
Expand Down
Loading