Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve EmbeddingGeneratorExtensions #5551

Merged
merged 2 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Improve EmbeddingGeneratorExtensions
- Renames GenerateAsync extension method (not the interface method) to be GenerateEmbeddingAsync, since it produces a single TEmbedding
- Adds GenerateEmbeddingVectorAsync, which returns a `ReadOnlyMemory<T>`
- Adds a GenerateAndZipEmbeddingsAsync, which creates a `List<KeyValuePair<TInput, TEmbedding>>` that pairs the inputs with the outputs.
  • Loading branch information
stephentoub committed Oct 23, 2024
commit 6c0b0068ce21c8751167483a6f9cf1d6d4635080
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,35 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.AI;

/// <summary>Provides a collection of static methods for extending <see cref="IEmbeddingGenerator{TValue,TEmbedding}"/> instances.</summary>
/// <summary>Provides a collection of static methods for extending <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> instances.</summary>
public static class EmbeddingGeneratorExtensions
{
/// <summary>Generates an embedding from the specified <paramref name="value"/>.</summary>
/// <typeparam name="TValue">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The type of embedding to generate.</typeparam>
/// <param name="generator">The embedding generator.</param>
/// <param name="value">A value from which an embedding will be generated.</param>
/// <param name="options">The embedding generation options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The generated embedding for the specified <paramref name="value"/>.</returns>
public static async Task<TEmbedding> GenerateAsync<TValue, TEmbedding>(
this IEmbeddingGenerator<TValue, TEmbedding> generator,
TValue value,
/// <returns>
/// The generated embedding for the specified <paramref name="value"/>.
/// </returns>
/// <remarks>
/// This operations is equivalent to using <see cref="IEmbeddingGenerator{TInput, TEmbedding}.GenerateAsync"/> with a
/// collection composed of the single <paramref name="value"/> and then returning the first embedding element from the
/// resulting <see cref="GeneratedEmbeddings{TEmbedding}"/> collection.
/// </remarks>
public static async Task<TEmbedding> GenerateEmbeddingAsync<TInput, TEmbedding>(
this IEmbeddingGenerator<TInput, TEmbedding> generator,
TInput value,
EmbeddingGenerationOptions? options = null,
CancellationToken cancellationToken = default)
where TEmbedding : Embedding
Expand All @@ -37,4 +46,64 @@ public static async Task<TEmbedding> GenerateAsync<TValue, TEmbedding>(

return embeddings[0];
}

/// <summary>Generates an embedding vector from the specified <paramref name="value"/>.</summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
/// <param name="generator">The embedding generator.</param>
/// <param name="value">A value from which an embedding will be generated.</param>
/// <param name="options">The embedding generation options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The generated embedding for the specified <paramref name="value"/>.</returns>
/// <remarks>
/// This operation is equivalent to using <see cref="GenerateEmbeddingAsync"/> and returning the
/// resulting <see cref="Embedding{T}"/>'s <see cref="Embedding{T}.Vector"/> property.
/// </remarks>
public static async Task<ReadOnlyMemory<TEmbedding>> GenerateEmbeddingVectorAsync<TInput, TEmbedding>(
this IEmbeddingGenerator<TInput, Embedding<TEmbedding>> generator,
TInput value,
EmbeddingGenerationOptions? options = null,
CancellationToken cancellationToken = default)
{
var embedding = await GenerateEmbeddingAsync(generator, value, options, cancellationToken).ConfigureAwait(false);
return embedding.Vector;
}

/// <summary>
/// Generates embeddings for each of the supplied <paramref name="values"/> and produces a list that pairs
/// each input with its resulting embedding.
/// </summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The type of embedding to generate.</typeparam>
/// <param name="generator">The embedding generator.</param>
/// <param name="values">The collection of values for which to generate embeddings.</param>
/// <param name="options">The embedding generation options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The generated embeddings.</returns>
public static async Task<IList<KeyValuePair<TInput, TEmbedding>>> GenerateAndZipEmbeddingsAsync<TInput, TEmbedding>(
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
this IEmbeddingGenerator<TInput, TEmbedding> generator,
IEnumerable<TInput> values,
EmbeddingGenerationOptions? options = null,
CancellationToken cancellationToken = default)
where TEmbedding : Embedding
{
_ = Throw.IfNull(generator);
_ = Throw.IfNull(values);

IList<TInput> inputs = values as IList<TInput> ?? values.ToList();
stephentoub marked this conversation as resolved.
Show resolved Hide resolved

var embeddings = await generator.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false);
if (embeddings.Count != inputs.Count)
{
throw new InvalidOperationException($"Expected the number of embeddings ({embeddings.Count}) to match the number of inputs ({inputs.Count}).");
}

List<KeyValuePair<TInput, TEmbedding>> results = new(embeddings.Count);
for (int i = 0; i < embeddings.Count; i++)
{
results.Add(new KeyValuePair<TInput, TEmbedding>(inputs[i], embeddings[i]));
}

return results;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Linq;
using System.Threading.Tasks;
using Xunit;

Expand All @@ -12,7 +13,9 @@ public class EmbeddingGeneratorExtensionsTests
[Fact]
public async Task GenerateAsync_InvalidArgs_ThrowsAsync()
{
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateEmbeddingAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateEmbeddingVectorAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateAndZipEmbeddingsAsync(["hello"]));
}

[Fact]
Expand All @@ -26,6 +29,35 @@ public async Task GenerateAsync_ReturnsSingleEmbeddingAsync()
Task.FromResult<GeneratedEmbeddings<Embedding<float>>>([result])
};

Assert.Same(result, await service.GenerateAsync("hello"));
Assert.Same(result, await service.GenerateEmbeddingAsync("hello"));
Assert.Equal(result.Vector, await service.GenerateEmbeddingVectorAsync("hello"));
}

[Theory]
[InlineData(0)]
[InlineData(1)]
[InlineData(10)]
public async Task GenerateAndZipEmbeddingsAsync_ReturnsExpectedList(int count)
{
string[] inputs = Enumerable.Range(0, count).Select(i => $"hello {i}").ToArray();
Embedding<float>[] embeddings = Enumerable
.Range(0, count)
.Select(i => new Embedding<float>(Enumerable.Range(i, 4).Select(i => (float)i).ToArray()))
.ToArray();

using TestEmbeddingGenerator service = new()
{
GenerateAsyncCallback = (values, options, cancellationToken) =>
Task.FromResult<GeneratedEmbeddings<Embedding<float>>>(new(embeddings))
};

var results = await service.GenerateAndZipEmbeddingsAsync(inputs);
Assert.NotNull(results);
Assert.Equal(count, results.Count);
for (int i = 0; i < count; i++)
{
Assert.Equal(inputs[i], results[i].Key);
Assert.Same(embeddings[i], results[i].Value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ public virtual async Task Caching_SameOutputsForSameInput()
.Use(CreateEmbeddingGenerator()!);

string input = "Red, White, and Blue";
var embedding1 = await generator.GenerateAsync(input);
var embedding2 = await generator.GenerateAsync(input);
var embedding3 = await generator.GenerateAsync(input + "... and Green");
var embedding4 = await generator.GenerateAsync(input);
var embedding1 = await generator.GenerateEmbeddingAsync(input);
var embedding2 = await generator.GenerateEmbeddingAsync(input);
var embedding3 = await generator.GenerateEmbeddingAsync(input + "... and Green");
var embedding4 = await generator.GenerateEmbeddingAsync(input);

var callCounter = generator.GetService<CallCountingEmbeddingGenerator>();
Assert.NotNull(callCounter);
Expand All @@ -114,7 +114,7 @@ public virtual async Task OpenTelemetry_CanEmitTracesAndMetrics()
.UseOpenTelemetry(sourceName: sourceName)
.Use(CreateEmbeddingGenerator()!);

_ = await embeddingGenerator.GenerateAsync("Hello, world!");
_ = await embeddingGenerator.GenerateEmbeddingAsync("Hello, world!");

Assert.Single(activities);
var activity = activities.Single();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ public async Task CachesSuccessResultsAsync()
};

// Make the initial request and do a quick sanity check
var result1 = await outer.GenerateAsync("abc");
var result1 = await outer.GenerateEmbeddingAsync("abc");
AssertEmbeddingsEqual(_expectedEmbedding, result1);
Assert.Equal(1, innerCallCount);

// Act
var result2 = await outer.GenerateAsync("abc");
var result2 = await outer.GenerateEmbeddingAsync("abc");

// Assert
Assert.Equal(1, innerCallCount);
Expand Down Expand Up @@ -134,8 +134,8 @@ public async Task AllowsConcurrentCallsAsync()
};

// Act 1: Concurrent calls before resolution are passed into the inner client
var result1 = outer.GenerateAsync("abc");
var result2 = outer.GenerateAsync("abc");
var result1 = outer.GenerateEmbeddingAsync("abc");
var result2 = outer.GenerateEmbeddingAsync("abc");

// Assert 1
Assert.Equal(2, innerCallCount);
Expand All @@ -146,7 +146,7 @@ public async Task AllowsConcurrentCallsAsync()
AssertEmbeddingsEqual(_expectedEmbedding, await result2);

// Act 2: Subsequent calls after completion are resolved from the cache
var result3 = await outer.GenerateAsync("abc");
var result3 = await outer.GenerateEmbeddingAsync("abc");
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, await result1);
}
Expand All @@ -169,12 +169,12 @@ public async Task DoesNotCacheExceptionResultsAsync()
JsonSerializerOptions = TestJsonSerializerContext.Default.Options,
};

var ex1 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateAsync("abc"));
var ex1 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateEmbeddingAsync("abc"));
Assert.Equal("some failure", ex1.Message);
Assert.Equal(1, innerCallCount);

// Act
var ex2 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateAsync("abc"));
var ex2 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateEmbeddingAsync("abc"));

// Assert
Assert.NotSame(ex1, ex2);
Expand Down Expand Up @@ -207,15 +207,15 @@ public async Task DoesNotCacheCanceledResultsAsync()
};

// First call gets cancelled
var result1 = outer.GenerateAsync("abc");
var result1 = outer.GenerateEmbeddingAsync("abc");
Assert.False(result1.IsCompleted);
Assert.Equal(1, innerCallCount);
resolutionTcs.SetCanceled();
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => result1);
Assert.True(result1.IsCanceled);

// Act/Assert: Second call can succeed
var result2 = await outer.GenerateAsync("abc");
var result2 = await outer.GenerateEmbeddingAsync("abc");
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result2);
}
Expand All @@ -241,11 +241,11 @@ public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
};

// Act: Call with two different options
var result1 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result2 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 2" }
});
Expand Down Expand Up @@ -277,11 +277,11 @@ public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync()
};

// Act: Call with two different options
var result1 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result2 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 2" }
});
Expand Down Expand Up @@ -315,7 +315,7 @@ public async Task CanResolveIDistributedCacheFromDI()

// Act: Make a request that should populate the cache
Assert.Empty(_storage.Keys);
var result = await outer.GenerateAsync("abc");
var result = await outer.GenerateEmbeddingAsync("abc");

// Assert
Assert.NotNull(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public async Task CompleteAsync_LogsStartAndCompletion(LogLevel level)
.UseLogging()
.Use(innerGenerator);

await generator.GenerateAsync("Blue whale");
await generator.GenerateEmbeddingAsync("Blue whale");

if (level is LogLevel.Trace)
{
Expand Down