diff --git a/dotnet/typeagent/src/knowproStorage/Sqlite/SqliteMessageCollection.cs b/dotnet/typeagent/src/knowproStorage/Sqlite/SqliteMessageCollection.cs index 492bdbf76..fb48d8372 100644 --- a/dotnet/typeagent/src/knowproStorage/Sqlite/SqliteMessageCollection.cs +++ b/dotnet/typeagent/src/knowproStorage/Sqlite/SqliteMessageCollection.cs @@ -480,7 +480,7 @@ public static IEnumerable GetSlice(SqliteDatabase db, int startOrdin ArgumentVerify.ThrowIfGreaterThan(startOrdinal, endOrdinal, nameof(startOrdinal)); return db.Enumerate(@" -SELECT chunks, chunk_uri, message_length, start_timestamp, tags, metadata, extra +SELECT msg_id, chunks, chunk_uri, message_length, start_timestamp, tags, metadata, extra FROM Messages WHERE msg_id >= @start_id AND msg_id < @end_id ORDER BY msg_id", (cmd) => diff --git a/dotnet/typeagent/src/typechat/TypeChat.csproj b/dotnet/typeagent/src/typechat/TypeChat.csproj index af11aba87..34e387653 100644 --- a/dotnet/typeagent/src/typechat/TypeChat.csproj +++ b/dotnet/typeagent/src/typechat/TypeChat.csproj @@ -19,4 +19,8 @@ + + + + diff --git a/dotnet/typeagent/tests/conversationMemory.test/PodcastTests.cs b/dotnet/typeagent/tests/conversationMemory.test/PodcastTests.cs index 2f835e4ea..ba0c211a8 100644 --- a/dotnet/typeagent/tests/conversationMemory.test/PodcastTests.cs +++ b/dotnet/typeagent/tests/conversationMemory.test/PodcastTests.cs @@ -27,8 +27,6 @@ private class TestTranscriptInfo public string name { get; set; } = string.Empty; public System.DateTime date { get; set; } = System.DateTime.Now; public uint length { get; set; } = 0; - public uint? participantCount { get; set; } = null; - public uint? messageCount { get; set; } = null; } private static TestTranscriptInfo GetTransscriptSmall() @@ -39,8 +37,6 @@ private static TestTranscriptInfo GetTransscriptSmall() name = "Test", date = System.DateTime.Parse("March 2024"), length = 15, - messageCount = 7, - participantCount = 5, }; } @@ -62,7 +58,7 @@ public async Task BuildIndexAsync() Assert.Equal(["hamlet", "lady bracknell", "macbeth", "richard", "sherlock holmes"], participants); var terms = await podcast.SemanticRefIndex.LookupTermAsync("misfortune"); - Assert.True(terms?.Count > 0); + Assert.True(terms!.Count > 0); } private async Task ImportTestPodcastAsync(TestTranscriptInfo podcastDetails, bool online) diff --git a/dotnet/typeagent/tests/testLib/MockModels.cs b/dotnet/typeagent/tests/testLib/MockModels.cs new file mode 100644 index 000000000..c8465a918 --- /dev/null +++ b/dotnet/typeagent/tests/testLib/MockModels.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.TypeChat; +using TypeAgent.AIClient; + +namespace TypeAgent.TestLib; + +public class MockModel_No_JSON_Response : IChatModel +{ + public Task CompleteAsync(Prompt prompt, TranslationSettings? settings, CancellationToken cancelToken) + { + + return Task.Run(() => "Mock response", cancelToken); + } + + public Task CompleteTextAsync(Prompt prompt, TranslationSettings? settings, CancellationToken cancelToken) + { + return Task.Run(() => "Mock response", cancelToken); + } +} + +public class MockModel_Partial_JSON_Response : IChatModel +{ + public Task CompleteAsync(Prompt prompt, TranslationSettings? settings, CancellationToken cancelToken) + { + + return Task.Run(() => "{ \"text\": \"partial json\"", cancelToken); + } + + public Task CompleteTextAsync(Prompt prompt, TranslationSettings? settings, CancellationToken cancelToken) + { + return Task.Run(() => "{ \"text\": \"partial json\"", cancelToken); + } +} diff --git a/dotnet/typeagent/tests/typeChat.test/ArgumentVerifyTests.cs b/dotnet/typeagent/tests/typeChat.test/ArgumentVerifyTests.cs new file mode 100644 index 000000000..847c82930 --- /dev/null +++ b/dotnet/typeagent/tests/typeChat.test/ArgumentVerifyTests.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using Xunit; + +namespace Microsoft.TypeChat.Tests; + +public class ArgumentVerifyTests +{ + [Fact] + public void Throw_ThrowsArgumentException_WithMessage() + { + var ex = Assert.Throws(() => ArgumentVerify.Throw("test message")); + Assert.Equal("test message", ex.Message); + } + + [Fact] + public void ThrowIfNull_WithNull_ThrowsArgumentNullException() + { + var ex = Assert.Throws(() => ArgumentVerify.ThrowIfNull(null, "param")); + Assert.Equal("param", ex.ParamName); + } + + [Fact] + public void ThrowIfNull_WithNonNull_DoesNotThrow() + { + ArgumentVerify.ThrowIfNull(new object(), "param"); + } + + [Fact] + public void ThrowIfNullOrEmpty_StringNull_ThrowsArgumentNullException() + { + var ex = Assert.Throws(() => ArgumentVerify.ThrowIfNullOrEmpty((string)null, "param")); + Assert.Equal("param", ex.ParamName); + } + + [Fact] + public void ThrowIfNullOrEmpty_StringEmpty_ThrowsArgumentException() + { + var ex = Assert.Throws(() => ArgumentVerify.ThrowIfNullOrEmpty("", "param")); + Assert.Equal("The value cannot be an empty string. (Parameter 'param')", ex.Message); + } + + [Fact] + public void ThrowIfNullOrEmpty_StringValid_DoesNotThrow() + { + ArgumentVerify.ThrowIfNullOrEmpty("valid", "param"); + } + + [Fact] + public void ThrowIfNullOrEmpty_ListNull_ThrowsArgumentNullException() + { + List list = null; + var ex = Assert.Throws(() => ArgumentVerify.ThrowIfNullOrEmpty(list, "param")); + Assert.Equal("param", ex.ParamName); + } + + [Fact] + public void ThrowIfNullOrEmpty_ListEmpty_ThrowsArgumentException() + { + var ex = Assert.Throws(() => ArgumentVerify.ThrowIfNullOrEmpty(new List(), "param")); + Assert.Equal("The list cannot be empty. (Parameter 'param')", ex.Message); + } + + [Fact] + public void ThrowIfNullOrEmpty_ListValid_DoesNotThrow() + { + ArgumentVerify.ThrowIfNullOrEmpty(new List { 1 }, "param"); + } +} diff --git a/dotnet/typeagent/tests/typeChat.test/EmbeddingExtensionTests.cs b/dotnet/typeagent/tests/typeChat.test/EmbeddingExtensionTests.cs new file mode 100644 index 000000000..16056e044 --- /dev/null +++ b/dotnet/typeagent/tests/typeChat.test/EmbeddingExtensionTests.cs @@ -0,0 +1,449 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using TypeAgent.Common; +using TypeAgent.Vector; +using Xunit; + +namespace Microsoft.TypeChat.Tests; + +public class EmbeddingExtensionTests +{ + // Mock implementation for testing + private class MockEmbedding : ICosineSimilarity + { + public float[] Vector { get; } + + public MockEmbedding(float[] vector) + { + Vector = vector; + } + + public double CosineSimilarity(MockEmbedding other) + { + // Simple dot product for normalized vectors + double sum = 0; + for (int i = 0; i < Vector.Length && i < other.Vector.Length; i++) + { + sum += Vector[i] * other.Vector[i]; + } + return sum; + } + } + + [Fact] + public void IndexOfNearest_ReturnsCorrectIndex() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }), + new MockEmbedding(new float[] { 0.0f, 1.0f, 0.0f }), + new MockEmbedding(new float[] { 0.0f, 0.0f, 1.0f }) + }; + var query = new MockEmbedding(new float[] { 0.1f, 0.9f, 0.1f }); + + // Act + var result = list.IndexOfNearest(query); + + // Assert + Assert.Equal(1, result.Item); + Assert.True(result.Score > 0.5); + } + + [Fact] + public void IndexOfNearest_EmptyList_ReturnsNegativeIndex() + { + // Arrange + var list = new List(); + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + + // Act + var result = list.IndexOfNearest(query); + + // Assert + Assert.Equal(-1, result.Item); + Assert.Equal(double.MinValue, result.Score); + } + + [Fact] + public void IndexOfNearest_WithMinScore_FiltersResults() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }), + new MockEmbedding(new float[] { 0.0f, 1.0f, 0.0f }), + new MockEmbedding(new float[] { 0.5f, 0.5f, 0.0f }) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + double minScore = 0.9; + + // Act + var result = list.IndexOfNearest(query, minScore); + + // Assert + Assert.Equal(0, result.Item); + Assert.True(result.Score >= minScore); + } + + [Fact] + public void IndexOfNearest_WithMinScore_NoMatch_ReturnsNegative() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 0.0f, 1.0f, 0.0f }), + new MockEmbedding(new float[] { 0.0f, 0.0f, 1.0f }) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + double minScore = 0.9; + + // Act + var result = list.IndexOfNearest(query, minScore); + + // Assert + Assert.Equal(-1, result.Item); + Assert.Equal(double.MinValue, result.Score); + } + + [Fact] + public void IndexesOfNearest_WithTopNCollection_ReturnsTopMatches() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }), + new MockEmbedding(new float[] { 0.9f, 0.1f, 0.0f }), + new MockEmbedding(new float[] { 0.0f, 1.0f, 0.0f }), + new MockEmbedding(new float[] { 0.0f, 0.0f, 1.0f }) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + var matches = TopNCollection.Create(2); + + // Act + list.IndexesOfNearest(query, matches); + var results = matches.ByRankAndClear(); + + // Assert + Assert.Equal(2, results.Count); + Assert.Equal(0, results[0].Item); // Closest match + Assert.Equal(1, results[1].Item); // Second closest + Assert.True(results[0].Score > results[1].Score); + } + + [Fact] + public void IndexesOfNearest_WithMinScore_FiltersResults() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }), + new MockEmbedding(new float[] { 0.9f, 0.1f, 0.0f }), + new MockEmbedding(new float[] { 0.0f, 1.0f, 0.0f }) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + var matches = TopNCollection.Create(10); + double minScore = 0.8; + + // Act + list.IndexesOfNearest(query, matches, minScore); + var results = matches.ByRankAndClear(); + + // Assert + Assert.Equal(2, results.Count); + Assert.All(results, r => Assert.True(r.Score >= minScore)); + } + + [Fact] + public void IndexesOfNearest_ThrowsOnNullMatches() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + + // Act & Assert + Assert.Throws(() => + list.IndexesOfNearest(query, null!)); + } + + [Fact] + public void IndexesOfNearest_WithMaxMatches_ReturnsCorrectCount() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }), + new MockEmbedding(new float[] { 0.9f, 0.1f, 0.0f }), + new MockEmbedding(new float[] { 0.8f, 0.2f, 0.0f }), + new MockEmbedding(new float[] { 0.7f, 0.3f, 0.0f }) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + int maxMatches = 2; + + // Act + var results = list.IndexesOfNearest(query, maxMatches); + + // Assert + Assert.Equal(2, results.Count); + Assert.Equal(0, results[0].Item); + Assert.Equal(1, results[1].Item); + } + + [Fact] + public void IndexesOfNearest_WithFilter_AppliesFilter() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }), // index 0 + new MockEmbedding(new float[] { 0.9f, 0.1f, 0.0f }), // index 1 + new MockEmbedding(new float[] { 0.8f, 0.2f, 0.0f }), // index 2 + new MockEmbedding(new float[] { 0.7f, 0.3f, 0.0f }) // index 3 + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + Func filter = (i) => i % 2 == 0; // Only even indexes + int maxMatches = 2; + + // Act + var results = list.IndexesOfNearest(query, filter, maxMatches); + + // Assert - Note: Current implementation has a bug, it doesn't add filtered items to matches + // This test documents the current behavior + Assert.Empty(results); + } + + [Fact] + public void IndexesOfNearest_WithFilter_ThrowsOnNullFilter() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + + // Act & Assert + Assert.Throws(() => + list.IndexesOfNearest(query, null!, 10)); + } + + [Fact] + public void IndexesOfNearestInSubset_ReturnsMatchesFromSubset() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }), // index 0 + new MockEmbedding(new float[] { 0.0f, 1.0f, 0.0f }), // index 1 + new MockEmbedding(new float[] { 0.9f, 0.1f, 0.0f }), // index 2 + new MockEmbedding(new float[] { 0.0f, 0.0f, 1.0f }) // index 3 + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + var subset = new List { 1, 2, 3 }; // Exclude index 0 + int maxMatches = 2; + + // Act + var results = list.IndexesOfNearestInSubset(query, subset, maxMatches); + + // Assert + Assert.Equal(2, results.Count); + Assert.Equal(2, results[0].Item); // Best match in subset + Assert.DoesNotContain(results, r => r.Item == 0); // Index 0 not in subset + } + + [Fact] + public void IndexesOfNearestInSubset_WithMinScore_FiltersResults() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }), + new MockEmbedding(new float[] { 0.9f, 0.1f, 0.0f }), + new MockEmbedding(new float[] { 0.0f, 1.0f, 0.0f }) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + var subset = new List { 0, 1, 2 }; + double minScore = 0.8; + + // Act + var results = list.IndexesOfNearestInSubset(query, subset, 10, minScore); + + // Assert + Assert.Equal(2, results.Count); + Assert.All(results, r => Assert.True(r.Score >= minScore)); + } + + [Fact] + public void IndexesOfNearestInSubset_ThrowsOnNullSubset() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + + // Act & Assert + Assert.Throws(() => + list.IndexesOfNearestInSubset(query, null!, 10)); + } + + [Fact] + public void KeysOfNearest_WithTopNCollection_ReturnsTopKeys() + { + // Arrange + var list = new List> + { + new KeyValuePair(100, new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f })), + new KeyValuePair(200, new MockEmbedding(new float[] { 0.9f, 0.1f, 0.0f })), + new KeyValuePair(300, new MockEmbedding(new float[] { 0.0f, 1.0f, 0.0f })) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + var matches = TopNCollection.Create(2); + + // Act + list.KeysOfNearest(query, matches); + var results = matches.ByRankAndClear(); + + // Assert + Assert.Equal(2, results.Count); + Assert.Equal(100, results[0].Item); + Assert.Equal(200, results[1].Item); + } + + [Fact] + public void KeysOfNearest_WithFilter_AppliesFilter() + { + // Arrange + var list = new List> + { + new KeyValuePair(100, new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f })), + new KeyValuePair(200, new MockEmbedding(new float[] { 0.9f, 0.1f, 0.0f })), + new KeyValuePair(300, new MockEmbedding(new float[] { 0.8f, 0.2f, 0.0f })) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + var matches = TopNCollection.Create(10); + Func filter = (key) => key >= 200; // Only keys 200 and above + + // Act + list.KeysOfNearest(query, matches, double.MinValue, filter); + var results = matches.ByRankAndClear(); + + // Assert + Assert.Equal(2, results.Count); + Assert.All(results, r => Assert.True(r.Item >= 200)); + } + + [Fact] + public void KeysOfNearest_ThrowsOnNullMatches() + { + // Arrange + var list = new List> + { + new KeyValuePair(100, new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f })) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + + // Act & Assert + Assert.Throws(() => + list.KeysOfNearest(query, null!)); + } + + [Fact] + public void KeysOfNearest_WithMaxMatches_ReturnsCorrectCount() + { + // Arrange + var list = new List> + { + new KeyValuePair(100, new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f })), + new KeyValuePair(200, new MockEmbedding(new float[] { 0.9f, 0.1f, 0.0f })), + new KeyValuePair(300, new MockEmbedding(new float[] { 0.8f, 0.2f, 0.0f })), + new KeyValuePair(400, new MockEmbedding(new float[] { 0.7f, 0.3f, 0.0f })) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + int maxMatches = 2; + + // Act + var results = list.KeysOfNearest(query, maxMatches); + + // Assert + Assert.Equal(2, results.Count); + Assert.Equal(100, results[0].Item); + Assert.Equal(200, results[1].Item); + } + + [Fact] + public void KeysOfNearest_WithMinScore_FiltersResults() + { + // Arrange + var list = new List> + { + new KeyValuePair(100, new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f })), + new KeyValuePair(200, new MockEmbedding(new float[] { 0.9f, 0.1f, 0.0f })), + new KeyValuePair(300, new MockEmbedding(new float[] { 0.0f, 1.0f, 0.0f })) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + double minScore = 0.8; + + // Act + var results = list.KeysOfNearest(query, 10, minScore); + + // Assert + Assert.Equal(2, results.Count); + Assert.All(results, r => Assert.True(r.Score >= minScore)); + } + + [Fact] + public void KeysOfNearest_EmptyList_ReturnsEmpty() + { + // Arrange + var list = new List>(); + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + + // Act + var results = list.KeysOfNearest(query, 10); + + // Assert + Assert.Empty(results); + } + + [Fact] + public void IndexesOfNearest_EmptyList_ReturnsEmpty() + { + // Arrange + var list = new List(); + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + + // Act + var results = list.IndexesOfNearest(query, 10); + + // Assert + Assert.Empty(results); + } + + [Fact] + public void IndexesOfNearestInSubset_EmptySubset_ReturnsEmpty() + { + // Arrange + var list = new List + { + new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }), + new MockEmbedding(new float[] { 0.0f, 1.0f, 0.0f }) + }; + var query = new MockEmbedding(new float[] { 1.0f, 0.0f, 0.0f }); + var subset = new List(); + + // Act + var results = list.IndexesOfNearestInSubset(query, subset, 10); + + // Assert + Assert.Empty(results); + } +} diff --git a/dotnet/typeagent/tests/typeChat.test/EmbeddingTests.cs b/dotnet/typeagent/tests/typeChat.test/EmbeddingTests.cs new file mode 100644 index 000000000..abbe47d0f --- /dev/null +++ b/dotnet/typeagent/tests/typeChat.test/EmbeddingTests.cs @@ -0,0 +1,421 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TypeAgent.Vector; + +namespace Microsoft.TypeChat.Tests; + +public class EmbeddingTests +{ + private static float[] CreateTestVector(params float[] values) + { + return values; + } + + private static float[] CreateRandomVector(int length, int seed = 0) + { + var random = new Random(seed); + var vector = new float[length]; + for (int i = 0; i < length; i++) + { + vector[i] = (float)random.NextDouble(); + } + return vector; + } + + [Fact] + public void Constructor_WithValidVector_CreatesEmbedding() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + + // Act + var embedding = new Embedding(vector); + + // Assert + Assert.NotNull(embedding.Vector); + Assert.Equal(3, embedding.Length); + Assert.Equal(vector, embedding.Vector); + } + + [Fact] + public void Constructor_WithNullVector_ThrowsArgumentNullException() + { + // Arrange + float[] vector = null; + + // Act & Assert + Assert.Throws(() => new Embedding(vector)); + } + + [Fact] + public void Empty_ReturnsEmptyEmbedding() + { + // Act + var empty = Embedding.Empty; + + // Assert + Assert.NotNull(empty.Vector); + Assert.Equal(0, empty.Length); + } + + [Fact] + public void Length_ReturnsVectorLength() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); + var embedding = new Embedding(vector); + + // Act + var length = embedding.Length; + + // Assert + Assert.Equal(5, length); + } + + [Fact] + public void AsSpan_ReturnsReadOnlySpan() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + var embedding = new Embedding(vector); + + // Act + var span = embedding.AsSpan(); + + // Assert + Assert.Equal(3, span.Length); + Assert.Equal(1.0f, span[0]); + Assert.Equal(2.0f, span[1]); + Assert.Equal(3.0f, span[2]); + } + + [Fact] + public void CosineSimilarity_WithIdenticalVectors_ReturnsOne() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + var embedding1 = new Embedding(vector); + var embedding2 = new Embedding(vector); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.Equal(1.0, similarity, precision: 5); + } + + [Fact] + public void CosineSimilarity_WithOrthogonalVectors_ReturnsZero() + { + // Arrange + var vector1 = CreateTestVector(1.0f, 0.0f); + var vector2 = CreateTestVector(0.0f, 1.0f); + var embedding1 = new Embedding(vector1); + var embedding2 = new Embedding(vector2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.Equal(0.0, similarity, precision: 5); + } + + [Fact] + public void CosineSimilarity_WithOppositeVectors_ReturnsNegativeOne() + { + // Arrange + var vector1 = CreateTestVector(1.0f, 2.0f, 3.0f); + var vector2 = CreateTestVector(-1.0f, -2.0f, -3.0f); + var embedding1 = new Embedding(vector1); + var embedding2 = new Embedding(vector2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.Equal(-1.0, similarity, precision: 5); + } + + [Fact] + public void CosineSimilarity_WithSpan_CalculatesCorrectly() + { + // Arrange + var vector1 = CreateTestVector(1.0f, 2.0f, 3.0f); + var vector2 = CreateTestVector(1.0f, 2.0f, 3.0f); + var embedding = new Embedding(vector1); + + // Act + var similarity = embedding.CosineSimilarity(vector2.AsSpan()); + + // Assert + Assert.Equal(1.0, similarity, precision: 5); + } + + [Fact] + public void DotProduct_WithVectors_CalculatesCorrectly() + { + // Arrange + var vector1 = CreateTestVector(1.0f, 2.0f, 3.0f); + var vector2 = CreateTestVector(4.0f, 5.0f, 6.0f); + var embedding1 = new Embedding(vector1); + var embedding2 = new Embedding(vector2); + + // Act + var dotProduct = embedding1.DotProduct(embedding2); + + // Assert + // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32 + Assert.Equal(32.0, dotProduct, precision: 5); + } + + [Fact] + public void DotProduct_WithZeroVector_ReturnsZero() + { + // Arrange + var vector1 = CreateTestVector(1.0f, 2.0f, 3.0f); + var vector2 = CreateTestVector(0.0f, 0.0f, 0.0f); + var embedding1 = new Embedding(vector1); + var embedding2 = new Embedding(vector2); + + // Act + var dotProduct = embedding1.DotProduct(embedding2); + + // Assert + Assert.Equal(0.0, dotProduct); + } + + [Fact] + public void ToNormalized_CreatesNormalizedEmbedding() + { + // Arrange + var vector = CreateTestVector(3.0f, 4.0f); // Length = 5 + var embedding = new Embedding(vector); + + // Act + var normalized = embedding.ToNormalized(); + + // Assert + Assert.NotNull(normalized.Vector); + Assert.Equal(2, normalized.Length); + Assert.Equal(0.6f, normalized.Vector[0], precision: 5); + Assert.Equal(0.8f, normalized.Vector[1], precision: 5); + } + + [Fact] + public void ToNormalized_DoesNotModifyOriginal() + { + // Arrange + var vector = CreateTestVector(3.0f, 4.0f); + var embedding = new Embedding(vector); + var originalValues = vector.ToArray(); + + // Act + var normalized = embedding.ToNormalized(); + + // Assert + Assert.Equal(originalValues, embedding.Vector); + } + + [Fact] + public void NormalizeInPlace_ModifiesVectorInPlace() + { + // Arrange + var vector = CreateTestVector(3.0f, 4.0f); // Length = 5 + var embedding = new Embedding(vector); + + // Act + embedding.NormalizeInPlace(); + + // Assert + Assert.Equal(0.6f, embedding.Vector[0], precision: 5); + Assert.Equal(0.8f, embedding.Vector[1], precision: 5); + } + + [Fact] + public void ToBytes_ConvertsVectorToBytes() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + var embedding = new Embedding(vector); + + // Act + var bytes = embedding.ToBytes(); + + // Assert + Assert.NotNull(bytes); + Assert.Equal(vector.Length * sizeof(float), bytes.Length); + } + + [Fact] + public void ToBytes_Static_ConvertsVectorToBytes() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + + // Act + var bytes = Embedding.ToBytes(vector); + + // Assert + Assert.NotNull(bytes); + Assert.Equal(vector.Length * sizeof(float), bytes.Length); + } + + [Fact] + public void FromBytes_ReconstructsVector() + { + // Arrange + var originalVector = CreateTestVector(1.0f, 2.0f, 3.0f); + var bytes = Embedding.ToBytes(originalVector); + + // Act + var reconstructedVector = Embedding.FromBytes(bytes); + + // Assert + Assert.NotNull(reconstructedVector); + Assert.Equal(originalVector.Length, reconstructedVector.Length); + Assert.Equal(originalVector, reconstructedVector); + } + + [Fact] + public void ToBytes_FromBytes_RoundTrip_PreservesData() + { + // Arrange + var vector = CreateRandomVector(128, seed: 42); + var embedding = new Embedding(vector); + + // Act + var bytes = embedding.ToBytes(); + var reconstructed = Embedding.FromBytes(bytes); + + // Assert + Assert.Equal(vector, reconstructed); + } + + [Fact] + public void ImplicitConversion_ToFloatArray_Works() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + var embedding = new Embedding(vector); + + // Act + float[] convertedArray = embedding; + + // Assert + Assert.Equal(vector, convertedArray); + } + + [Fact] + public void ImplicitConversion_FromFloatArray_Works() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + + // Act + Embedding embedding = vector; + + // Assert + Assert.Equal(vector, embedding.Vector); + } + + [Fact] + public void ImplicitConversion_ToReadOnlySpan_Works() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + var embedding = new Embedding(vector); + + // Act + ReadOnlySpan span = embedding; + + // Assert + Assert.Equal(3, span.Length); + Assert.Equal(1.0f, span[0]); + Assert.Equal(2.0f, span[1]); + Assert.Equal(3.0f, span[2]); + } + + [Fact] + public void CosineSimilarity_WithSimilarVectors_ReturnsHighSimilarity() + { + // Arrange + var vector1 = CreateTestVector(1.0f, 2.0f, 3.0f); + var vector2 = CreateTestVector(1.1f, 2.1f, 3.1f); + var embedding1 = new Embedding(vector1); + var embedding2 = new Embedding(vector2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.True(similarity > 0.99, $"Expected similarity > 0.99, but got {similarity}"); + } + + [Fact] + public void CosineSimilarity_WithDifferentVectors_ReturnsLowSimilarity() + { + // Arrange + var vector1 = CreateTestVector(1.0f, 0.0f, 0.0f); + var vector2 = CreateTestVector(0.0f, 0.0f, 1.0f); + var embedding1 = new Embedding(vector1); + var embedding2 = new Embedding(vector2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.True(similarity < 0.1, $"Expected similarity < 0.1, but got {similarity}"); + } + + [Fact] + public void Vector_Property_ReturnsOriginalVector() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + + // Act + var embedding = new Embedding(vector); + + // Assert + Assert.Same(vector, embedding.Vector); + } + + [Fact] + public void NormalizeInPlace_WithUnitVector_RemainsUnchanged() + { + // Arrange + var vector = CreateTestVector(1.0f, 0.0f, 0.0f); // Already unit vector + var embedding = new Embedding(vector); + + // Act + embedding.NormalizeInPlace(); + + // Assert + Assert.Equal(1.0f, embedding.Vector[0], precision: 5); + Assert.Equal(0.0f, embedding.Vector[1], precision: 5); + Assert.Equal(0.0f, embedding.Vector[2], precision: 5); + } + + [Fact] + public void ToNormalized_WithLargeVector_WorksCorrectly() + { + // Arrange + var vector = CreateRandomVector(1536, seed: 100); // Common embedding size + var embedding = new Embedding(vector); + + // Act + var normalized = embedding.ToNormalized(); + + // Assert + Assert.Equal(1536, normalized.Length); + // Verify it's normalized by checking the L2 norm is approximately 1 + var sumOfSquares = normalized.Vector.Sum(x => x * x); + Assert.Equal(1.0, sumOfSquares, precision: 4); + } +} diff --git a/dotnet/typeagent/tests/typeChat.test/JsonTests.cs b/dotnet/typeagent/tests/typeChat.test/JsonTests.cs new file mode 100644 index 000000000..837b92688 --- /dev/null +++ b/dotnet/typeagent/tests/typeChat.test/JsonTests.cs @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.TypeChat.Tests; + +public class JsonTests +{ + private class TestClass + { + public int Id { get; set; } + public string? Name { get; set; } + } + + [Fact] + public void Stringify_Object_Indented_True() + { + var obj = new TestClass { Id = 1, Name = "Test" }; + string json = Microsoft.TypeChat.Json.Stringify(obj, true); + Assert.Contains(Environment.NewLine, json); + Assert.Contains("\"Id\"", json); + Assert.Contains("\"Name\"", json); + } + + [Fact] + public void Stringify_Object_Indented_False() + { + var obj = new TestClass { Id = 2, Name = "NoIndent" }; + string json = Microsoft.TypeChat.Json.Stringify(obj, false); + Assert.DoesNotContain(Environment.NewLine, json); + Assert.Contains("\"Id\"", json); + Assert.Contains("\"Name\"", json); + } + + [Fact] + public void Stringify_Generic_Indented_True() + { + var obj = new TestClass { Id = 3, Name = "Generic" }; + string json = Microsoft.TypeChat.Json.Stringify(obj, true); + Assert.Contains(Environment.NewLine, json); + Assert.Contains("\"Id\"", json); + Assert.Contains("\"Name\"", json); + } + + [Fact] + public void Parse_Object_From_Json_String() + { + string json = "{\"Id\":4,\"Name\":\"ParseTest\"}"; + var result = (TestClass?)Microsoft.TypeChat.Json.Parse(json, typeof(TestClass)); + Assert.NotNull(result); + Assert.Equal(4, result.Id); + Assert.Equal("ParseTest", result.Name); + } + + [Fact] + public void Parse_Generic_From_Json_String() + { + string json = "{\"Id\":5,\"Name\":\"GenericParse\"}"; + var result = Microsoft.TypeChat.Json.Parse(json); + Assert.NotNull(result); + Assert.Equal(5, result.Id); + Assert.Equal("GenericParse", result.Name); + } + + [Fact] + public void Parse_Generic_From_Stream() + { + var obj = new TestClass { Id = 6, Name = "StreamParse" }; + string json = Microsoft.TypeChat.Json.Stringify(obj, false); + using var stream = new MemoryStream(Encoding.UTF8.GetBytes(json)); + var result = Microsoft.TypeChat.Json.Parse(stream); + Assert.NotNull(result); + Assert.Equal(6, result.Id); + Assert.Equal("StreamParse", result.Name); + } + + [Fact] + public void DefaultOptions_Returns_Options() + { + var options = Microsoft.TypeChat.Json.DefaultOptions(); + Assert.NotNull(options); + Assert.True(options.WriteIndented == false || options.WriteIndented == true); + } +} diff --git a/dotnet/typeagent/tests/typeChat.test/JsonTranslatorTests.cs b/dotnet/typeagent/tests/typeChat.test/JsonTranslatorTests.cs new file mode 100644 index 000000000..5f8f6acc1 --- /dev/null +++ b/dotnet/typeagent/tests/typeChat.test/JsonTranslatorTests.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Threading; +using System.Threading.Tasks; +using TypeAgent.AIClient; +using TypeAgent.KnowPro; +using Xunit; + +namespace Microsoft.TypeChat.Tests; + +public class SampleType : SentimentResponse +{ + public string Name { get; set; } = string.Empty; + public int Value { get; set; } +} + +public class JsonTranslatorTests : TestWithData +{ + + private OpenAIChatModel _model; + private SchemaText _schema; + private JsonTranslator _translator; + private JsonSerializerTypeValidator _validator; + + public JsonTranslatorTests() : base(true) + { + _model = (OpenAIChatModel)ModelUtils.CreateTestChatModel(nameof(JsonTranslatorTests)); + _schema = SchemaText.Load("./SentimentSchema.ts"); + _translator = new JsonTranslator( + ModelUtils.CreateTestChatModel(nameof(JsonTranslatorTests)), + _schema + ); + + _validator = new JsonSerializerTypeValidator(_schema); + } + + [Fact] + public void Constructor_InitializesProperties() + { + Assert.Equal(_model.Settings.ModelName, ((OpenAIChatModel)_translator.Model).Settings.ModelName); + Assert.Equal(_model.Settings.Endpoint, ((OpenAIChatModel)_translator.Model).Settings.Endpoint); + Assert.Equal(_validator.Schema.TypeFullName, _translator.Validator.Schema.TypeFullName); + Assert.Equal(_validator.Schema.Schema, _translator.Validator.Schema.Schema); + Assert.NotNull(_translator.Prompts); + Assert.NotNull(_translator.TranslationSettings); + Assert.Equal(JsonTranslator.DefaultMaxRepairAttempts, _translator.MaxRepairAttempts); + } + + [Fact] + public void Validator_Setter_UpdatesValidator() + { + var model = ModelUtils.CreateTestChatModel(nameof(JsonTranslatorTests)); + var validator = new JsonSerializerTypeValidator(_schema); + var translator = new JsonTranslator(model, validator); + + translator.Validator = _validator; + Assert.Equal(_validator, translator.Validator); + } + + [Fact] + public void MaxRepairAttempts_Setter_HandlesNegativeValues() + { + _translator.MaxRepairAttempts = -5; + Assert.Equal(0, _translator.MaxRepairAttempts); + } + + [Fact] + public async Task TranslateAsync_ThrowsOnInvalidNoAsync() + { + var prompt = new Prompt("Test request"); + var mockModel = new MockModel_No_JSON_Response(); + + var translator = new JsonTranslator(mockModel, _schema) + { + MaxRepairAttempts = 1 + }; + + await Assert.ThrowsAsync(async () => + { + await translator.TranslateAsync(prompt, null, null, CancellationToken.None); + }); + } + + [Fact] + public async Task TranslateAsync_ThrowsOnInvalidJsonAsync() + { + var prompt = new Prompt("Test request"); + var mockModel = new MockModel_Partial_JSON_Response(); + + var translator = new JsonTranslator(mockModel, _schema) + { + MaxRepairAttempts = 1 + }; + + await Assert.ThrowsAsync(async () => + { + await translator.TranslateAsync(prompt, null, null, CancellationToken.None); + }); + } +} diff --git a/dotnet/typeagent/tests/typeChat.test/NormalizedEmbeddingBTests.cs b/dotnet/typeagent/tests/typeChat.test/NormalizedEmbeddingBTests.cs new file mode 100644 index 000000000..b7f00768f --- /dev/null +++ b/dotnet/typeagent/tests/typeChat.test/NormalizedEmbeddingBTests.cs @@ -0,0 +1,442 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TypeAgent.Vector; + +namespace Microsoft.TypeChat.Tests; + +public class NormalizedEmbeddingBTests +{ + private static float[] CreateTestVector(params float[] values) + { + return values; + } + + private static float[] CreateRandomVector(int length, int seed = 0) + { + var random = new Random(seed); + var vector = new float[length]; + for (int i = 0; i < length; i++) + { + vector[i] = (float)random.NextDouble(); + } + return vector; + } + + private static float[] CreateNormalizedVector(params float[] values) + { + var sumOfSquares = values.Sum(x => x * x); + var magnitude = MathF.Sqrt(sumOfSquares); + return values.Select(x => x / magnitude).ToArray(); + } + + [Fact] + public void Constructor_WithValidByteArray_CreatesNormalizedEmbeddingB() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + var bytes = Embedding.ToBytes(vector); + + // Act + var embedding = new NormalizedEmbeddingB(bytes); + + // Assert + Assert.NotNull(embedding.Vector); + Assert.Equal(12, embedding.Length); // 3 floats * 4 bytes + Assert.Equal(bytes, embedding.Vector); + } + + [Fact] + public void Constructor_WithNullByteArray_ThrowsArgumentNullException() + { + // Arrange + byte[] bytes = null; + + // Act & Assert + Assert.Throws(() => new NormalizedEmbeddingB(bytes)); + } + + [Fact] + public void Length_ReturnsVectorLength() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); + var bytes = Embedding.ToBytes(vector); + var embedding = new NormalizedEmbeddingB(bytes); + + // Act + var length = embedding.Length; + + // Assert + Assert.Equal(20, length); // 5 floats * 4 bytes + } + + [Fact] + public void AsSpan_ReturnsReadOnlySpanOfFloats() + { + // Arrange + var vector = CreateTestVector(0.6f, 0.8f); + var bytes = Embedding.ToBytes(vector); + var embedding = new NormalizedEmbeddingB(bytes); + + // Act + var span = embedding.AsSpan(); + + // Assert + Assert.Equal(2, span.Length); + Assert.Equal(0.6f, span[0]); + Assert.Equal(0.8f, span[1]); + } + + [Fact] + public void ToEmbedding_ConvertsToEmbedding() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + var bytes = Embedding.ToBytes(vector); + var embeddingB = new NormalizedEmbeddingB(bytes); + + // Act + var embedding = embeddingB.ToEmbedding(); + + // Assert + Assert.Equal(3, embedding.Length); + Assert.Equal(vector[0], embedding.Vector[0]); + Assert.Equal(vector[1], embedding.Vector[1]); + Assert.Equal(vector[2], embedding.Vector[2]); + } + + [Fact] + public void CosineSimilarity_WithIdenticalNormalizedEmbeddingB_ReturnsOne() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var bytes = Embedding.ToBytes(vector); + var embedding1 = new NormalizedEmbeddingB(bytes); + var embedding2 = new NormalizedEmbeddingB(bytes); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.Equal(1.0, similarity, precision: 5); + } + + [Fact] + public void CosineSimilarity_WithOrthogonalVectors_ReturnsZero() + { + // Arrange + var vector1 = CreateNormalizedVector(1.0f, 0.0f); + var vector2 = CreateNormalizedVector(0.0f, 1.0f); + var bytes1 = Embedding.ToBytes(vector1); + var bytes2 = Embedding.ToBytes(vector2); + var embedding1 = new NormalizedEmbeddingB(bytes1); + var embedding2 = new NormalizedEmbeddingB(bytes2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.Equal(0.0, similarity, precision: 5); + } + + [Fact] + public void CosineSimilarity_WithOppositeVectors_ReturnsNegativeOne() + { + // Arrange + var vector1 = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var vector2 = CreateNormalizedVector(-1.0f, -2.0f, -3.0f); + var bytes1 = Embedding.ToBytes(vector1); + var bytes2 = Embedding.ToBytes(vector2); + var embedding1 = new NormalizedEmbeddingB(bytes1); + var embedding2 = new NormalizedEmbeddingB(bytes2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.Equal(-1.0, similarity, precision: 5); + } + + [Fact] + public void CosineSimilarity_WithNormalizedEmbedding_CalculatesCorrectly() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var bytes = Embedding.ToBytes(vector); + var embeddingB = new NormalizedEmbeddingB(bytes); + var embedding = new NormalizedEmbedding(vector); + + // Act + var similarity = embeddingB.CosineSimilarity(embedding); + + // Assert + Assert.Equal(1.0, similarity, precision: 5); + } + + [Fact] + public void CosineSimilarity_BetweenNormalizedEmbeddingBAndNormalizedEmbedding_IsSymmetric() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var bytes = Embedding.ToBytes(vector); + var embeddingB = new NormalizedEmbeddingB(bytes); + var embedding = new NormalizedEmbedding(vector); + + // Act + var similarity1 = embeddingB.CosineSimilarity(embedding); + var similarity2 = embedding.CosineSimilarity(embeddingB); + + // Assert + Assert.Equal(similarity1, similarity2, precision: 10); + } + + [Fact] + public void CosineSimilarity_UsesOptimizedDotProduct() + { + // Arrange - normalized vectors, so cosine similarity = dot product + var vector1 = CreateNormalizedVector(3.0f, 4.0f); + var vector2 = CreateNormalizedVector(5.0f, 12.0f); + var bytes1 = Embedding.ToBytes(vector1); + var bytes2 = Embedding.ToBytes(vector2); + var embedding1 = new NormalizedEmbeddingB(bytes1); + var embedding2 = new NormalizedEmbeddingB(bytes2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + // For normalized vectors: (0.6, 0.8) · (0.3846..., 0.9230...) ≈ 0.9692 + Assert.True(similarity > 0.95 && similarity < 1.0); + } + + [Fact] + public void ImplicitConversion_ToByteArray_Works() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var bytes = Embedding.ToBytes(vector); + var embedding = new NormalizedEmbeddingB(bytes); + + // Act + byte[] convertedArray = embedding; + + // Assert + Assert.Equal(bytes, convertedArray); + } + + [Fact] + public void ImplicitConversion_ToReadOnlySpan_Works() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var bytes = Embedding.ToBytes(vector); + var embedding = new NormalizedEmbeddingB(bytes); + + // Act + ReadOnlySpan span = embedding; + + // Assert + Assert.Equal(3, span.Length); + Assert.Equal(vector[0], span[0], precision: 5); + Assert.Equal(vector[1], span[1], precision: 5); + Assert.Equal(vector[2], span[2], precision: 5); + } + + [Fact] + public void Vector_Property_ReturnsOriginalByteArray() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var bytes = Embedding.ToBytes(vector); + + // Act + var embedding = new NormalizedEmbeddingB(bytes); + + // Assert + Assert.Same(bytes, embedding.Vector); + } + + [Fact] + public void CosineSimilarity_WithSimilarVectors_ReturnsHighSimilarity() + { + // Arrange + var vector1 = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var vector2 = CreateNormalizedVector(1.1f, 2.1f, 3.1f); + var bytes1 = Embedding.ToBytes(vector1); + var bytes2 = Embedding.ToBytes(vector2); + var embedding1 = new NormalizedEmbeddingB(bytes1); + var embedding2 = new NormalizedEmbeddingB(bytes2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.True(similarity > 0.99, $"Expected similarity > 0.99, but got {similarity}"); + } + + [Fact] + public void CosineSimilarity_WithDifferentVectors_ReturnsLowSimilarity() + { + // Arrange + var vector1 = CreateNormalizedVector(1.0f, 0.0f, 0.0f); + var vector2 = CreateNormalizedVector(0.0f, 0.0f, 1.0f); + var bytes1 = Embedding.ToBytes(vector1); + var bytes2 = Embedding.ToBytes(vector2); + var embedding1 = new NormalizedEmbeddingB(bytes1); + var embedding2 = new NormalizedEmbeddingB(bytes2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.True(similarity < 0.1, $"Expected similarity < 0.1, but got {similarity}"); + } + + [Fact] + public void ToEmbedding_FromBytes_RoundTrip_PreservesData() + { + // Arrange + var vector = CreateRandomVector(128, seed: 42); + var bytes = Embedding.ToBytes(vector); + var embeddingB = new NormalizedEmbeddingB(bytes); + + // Act + var embedding = embeddingB.ToEmbedding(); + var reconstructedBytes = embedding.ToBytes(); + + // Assert + Assert.Equal(bytes, reconstructedBytes); + } + + [Fact] + public void AsSpan_CanBeUsedInCalculations() + { + // Arrange + var vector1 = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var vector2 = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var bytes = Embedding.ToBytes(vector1); + var embedding = new NormalizedEmbeddingB(bytes); + + // Act + var span = embedding.AsSpan(); + var dotProduct = 0.0f; + for (int i = 0; i < span.Length; i++) + { + dotProduct += span[i] * vector2[i]; + } + + // Assert + Assert.Equal(1.0, dotProduct, precision: 5); + } + + [Fact] + public void Constructor_WithLargeVector_WorksCorrectly() + { + // Arrange + var vector = CreateRandomVector(1536, seed: 100); // Common embedding size + var bytes = Embedding.ToBytes(vector); + + // Act + var embedding = new NormalizedEmbeddingB(bytes); + + // Assert + Assert.Equal(6144, embedding.Length); // 1536 floats * 4 bytes + var span = embedding.AsSpan(); + Assert.Equal(1536, span.Length); + } + + [Fact] + public void ToEmbedding_WithLargeVector_PreservesAllValues() + { + // Arrange + var vector = CreateRandomVector(1536, seed: 200); + var bytes = Embedding.ToBytes(vector); + var embeddingB = new NormalizedEmbeddingB(bytes); + + // Act + var embedding = embeddingB.ToEmbedding(); + + // Assert + Assert.Equal(vector.Length, embedding.Length); + for (int i = 0; i < vector.Length; i++) + { + Assert.Equal(vector[i], embedding.Vector[i], precision: 5); + } + } + + [Fact] + public void AsSpan_MultipleCallsReturnSameData() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + var bytes = Embedding.ToBytes(vector); + var embedding = new NormalizedEmbeddingB(bytes); + + // Act + var span1 = embedding.AsSpan(); + var span2 = embedding.AsSpan(); + + // Assert + Assert.Equal(span1.Length, span2.Length); + for (int i = 0; i < span1.Length; i++) + { + Assert.Equal(span1[i], span2[i]); + } + } + + [Fact] + public void CosineSimilarity_WithEmptyVectors_HandlesGracefully() + { + // Arrange + var vector = CreateTestVector(); + var bytes = Embedding.ToBytes(vector); + var embedding1 = new NormalizedEmbeddingB(bytes); + var embedding2 = new NormalizedEmbeddingB(bytes); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.Equal(0.0, similarity); + } + + [Fact] + public void NormalizedEmbeddingB_StoresDataAsByteArray() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + var bytes = Embedding.ToBytes(vector); + + // Act + var embedding = new NormalizedEmbeddingB(bytes); + + // Assert + Assert.IsType(embedding.Vector); + Assert.Equal(bytes.Length, embedding.Vector.Length); + } + + [Fact] + public void CosineSimilarity_IsCommutative() + { + // Arrange + var vector1 = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var vector2 = CreateNormalizedVector(4.0f, 5.0f, 6.0f); + var bytes1 = Embedding.ToBytes(vector1); + var bytes2 = Embedding.ToBytes(vector2); + var embedding1 = new NormalizedEmbeddingB(bytes1); + var embedding2 = new NormalizedEmbeddingB(bytes2); + + // Act + var similarity1to2 = embedding1.CosineSimilarity(embedding2); + var similarity2to1 = embedding2.CosineSimilarity(embedding1); + + // Assert + Assert.Equal(similarity1to2, similarity2to1, precision: 10); + } +} diff --git a/dotnet/typeagent/tests/typeChat.test/NormalizedEmbeddingTests.cs b/dotnet/typeagent/tests/typeChat.test/NormalizedEmbeddingTests.cs new file mode 100644 index 000000000..24345d1c7 --- /dev/null +++ b/dotnet/typeagent/tests/typeChat.test/NormalizedEmbeddingTests.cs @@ -0,0 +1,421 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TypeAgent.Vector; + +namespace Microsoft.TypeChat.Tests; + +public class NormalizedEmbeddingTests +{ + private static float[] CreateTestVector(params float[] values) + { + return values; + } + + private static float[] CreateRandomVector(int length, int seed = 0) + { + var random = new Random(seed); + var vector = new float[length]; + for (int i = 0; i < length; i++) + { + vector[i] = (float)random.NextDouble(); + } + return vector; + } + + private static float[] CreateNormalizedVector(params float[] values) + { + var sumOfSquares = values.Sum(x => x * x); + var magnitude = MathF.Sqrt(sumOfSquares); + return values.Select(x => x / magnitude).ToArray(); + } + + [Fact] + public void Constructor_WithValidVector_CreatesNormalizedEmbedding() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f); + + // Act + var embedding = new NormalizedEmbedding(vector); + + // Assert + Assert.NotNull(embedding.Vector); + Assert.Equal(3, embedding.Length); + Assert.Equal(vector, embedding.Vector); + } + + [Fact] + public void Constructor_WithNullVector_ThrowsArgumentNullException() + { + // Arrange + float[] vector = null; + + // Act & Assert + Assert.Throws(() => new NormalizedEmbedding(vector)); + } + + [Fact] + public void Length_ReturnsVectorLength() + { + // Arrange + var vector = CreateTestVector(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); + var embedding = new NormalizedEmbedding(vector); + + // Act + var length = embedding.Length; + + // Assert + Assert.Equal(5, length); + } + + [Fact] + public void AsSpan_ReturnsReadOnlySpan() + { + // Arrange + var vector = CreateTestVector(0.6f, 0.8f); + var embedding = new NormalizedEmbedding(vector); + + // Act + var span = embedding.AsSpan(); + + // Assert + Assert.Equal(2, span.Length); + Assert.Equal(0.6f, span[0]); + Assert.Equal(0.8f, span[1]); + } + + [Fact] + public void CosineSimilarity_WithIdenticalVectors_ReturnsOne() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var embedding1 = new NormalizedEmbedding(vector); + var embedding2 = new NormalizedEmbedding(vector); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.Equal(1.0, similarity, precision: 5); + } + + [Fact] + public void CosineSimilarity_WithOrthogonalVectors_ReturnsZero() + { + // Arrange + var vector1 = CreateNormalizedVector(1.0f, 0.0f); + var vector2 = CreateNormalizedVector(0.0f, 1.0f); + var embedding1 = new NormalizedEmbedding(vector1); + var embedding2 = new NormalizedEmbedding(vector2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.Equal(0.0, similarity, precision: 5); + } + + [Fact] + public void CosineSimilarity_WithOppositeVectors_ReturnsNegativeOne() + { + // Arrange + var vector1 = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var vector2 = CreateNormalizedVector(-1.0f, -2.0f, -3.0f); + var embedding1 = new NormalizedEmbedding(vector1); + var embedding2 = new NormalizedEmbedding(vector2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.Equal(-1.0, similarity, precision: 5); + } + + [Fact] + public void CosineSimilarity_WithNormalizedEmbeddingB_CalculatesCorrectly() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var embedding = new NormalizedEmbedding(vector); + var bytes = Embedding.ToBytes(vector); + var embeddingB = new NormalizedEmbeddingB(bytes); + + // Act + var similarity = embedding.CosineSimilarity(embeddingB); + + // Assert + Assert.Equal(1.0, similarity, precision: 5); + } + + [Fact] + public void CosineSimilarity_UsesOptimizedDotProduct() + { + // Arrange - normalized vectors, so cosine similarity = dot product + var vector1 = CreateNormalizedVector(3.0f, 4.0f); + var vector2 = CreateNormalizedVector(5.0f, 12.0f); + var embedding1 = new NormalizedEmbedding(vector1); + var embedding2 = new NormalizedEmbedding(vector2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + // For normalized vectors: (0.6, 0.8) · (0.3846..., 0.9230...) ≈ 0.9692 + Assert.True(similarity > 0.95 && similarity < 1.0); + } + + [Fact] + public void ToBytes_ConvertsVectorToBytes() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var embedding = new NormalizedEmbedding(vector); + + // Act + var bytes = embedding.ToBytes(); + + // Assert + Assert.NotNull(bytes); + Assert.Equal(vector.Length * sizeof(float), bytes.Length); + } + + [Fact] + public void FromArray_WithNormalize_CreatesNormalizedEmbedding() + { + // Arrange + var vector = CreateTestVector(3.0f, 4.0f); // Length = 5 + + // Act + var embedding = NormalizedEmbedding.FromArray(vector, normalize: true); + + // Assert + Assert.NotNull(embedding.Vector); + Assert.Equal(2, embedding.Length); + Assert.Equal(0.6f, embedding.Vector[0], precision: 5); + Assert.Equal(0.8f, embedding.Vector[1], precision: 5); + } + + [Fact] + public void FromArray_WithoutNormalize_CreatesEmbeddingWithOriginalVector() + { + // Arrange + var vector = CreateTestVector(3.0f, 4.0f); + + // Act + var embedding = NormalizedEmbedding.FromArray(vector, normalize: false); + + // Assert + Assert.NotNull(embedding.Vector); + Assert.Equal(2, embedding.Length); + Assert.Equal(3.0f, embedding.Vector[0]); + Assert.Equal(4.0f, embedding.Vector[1]); + } + + [Fact] + public void FromArray_DefaultParameter_NormalizesVector() + { + // Arrange + var vector = CreateTestVector(3.0f, 4.0f); + + // Act + var embedding = NormalizedEmbedding.FromArray(vector); + + // Assert + Assert.Equal(0.6f, embedding.Vector[0], precision: 5); + Assert.Equal(0.8f, embedding.Vector[1], precision: 5); + } + + [Fact] + public void ImplicitConversion_ToFloatArray_Works() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var embedding = new NormalizedEmbedding(vector); + + // Act + float[] convertedArray = embedding; + + // Assert + Assert.Equal(vector, convertedArray); + } + + [Fact] + public void ImplicitConversion_ToReadOnlySpan_Works() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var embedding = new NormalizedEmbedding(vector); + + // Act + ReadOnlySpan span = embedding; + + // Assert + Assert.Equal(3, span.Length); + Assert.Equal(vector[0], span[0], precision: 5); + Assert.Equal(vector[1], span[1], precision: 5); + Assert.Equal(vector[2], span[2], precision: 5); + } + + [Fact] + public void ToBytes_FromBytes_RoundTrip_PreservesData() + { + // Arrange + var vector = CreateRandomVector(128, seed: 42); + var embedding = NormalizedEmbedding.FromArray(vector, normalize: true); + + // Act + var bytes = embedding.ToBytes(); + var reconstructed = Embedding.FromBytes(bytes); + + // Assert + Assert.Equal(embedding.Vector, reconstructed); + } + + [Fact] + public void Vector_Property_ReturnsOriginalVector() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + + // Act + var embedding = new NormalizedEmbedding(vector); + + // Assert + Assert.Same(vector, embedding.Vector); + } + + [Fact] + public void CosineSimilarity_WithSimilarVectors_ReturnsHighSimilarity() + { + // Arrange + var vector1 = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var vector2 = CreateNormalizedVector(1.1f, 2.1f, 3.1f); + var embedding1 = new NormalizedEmbedding(vector1); + var embedding2 = new NormalizedEmbedding(vector2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.True(similarity > 0.99, $"Expected similarity > 0.99, but got {similarity}"); + } + + [Fact] + public void CosineSimilarity_WithDifferentVectors_ReturnsLowSimilarity() + { + // Arrange + var vector1 = CreateNormalizedVector(1.0f, 0.0f, 0.0f); + var vector2 = CreateNormalizedVector(0.0f, 0.0f, 1.0f); + var embedding1 = new NormalizedEmbedding(vector1); + var embedding2 = new NormalizedEmbedding(vector2); + + // Act + var similarity = embedding1.CosineSimilarity(embedding2); + + // Assert + Assert.True(similarity < 0.1, $"Expected similarity < 0.1, but got {similarity}"); + } + + [Fact] + public void FromArray_WithLargeVector_WorksCorrectly() + { + // Arrange + var vector = CreateRandomVector(1536, seed: 100); // Common embedding size + + // Act + var embedding = NormalizedEmbedding.FromArray(vector, normalize: true); + + // Assert + Assert.Equal(1536, embedding.Length); + // Verify it's normalized by checking the L2 norm is approximately 1 + var sumOfSquares = embedding.Vector.Sum(x => x * x); + Assert.Equal(1.0, sumOfSquares, precision: 4); + } + + [Fact] + public void FromArray_WithZeroVector_HandlesGracefully() + { + // Arrange + var vector = CreateTestVector(0.0f, 0.0f, 0.0f); + + // Act + var embedding = NormalizedEmbedding.FromArray(vector, normalize: true); + + // Assert + Assert.Equal(3, embedding.Length); + // Normalization of zero vector results in NaN or Infinity + Assert.True(float.IsNaN(embedding.Vector[0]) || float.IsInfinity(embedding.Vector[0])); + } + + [Fact] + public void CosineSimilarity_BetweenNormalizedAndNormalizedB_IsSymmetric() + { + // Arrange + var vector = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var embedding = new NormalizedEmbedding(vector); + var bytes = Embedding.ToBytes(vector); + var embeddingB = new NormalizedEmbeddingB(bytes); + + // Act + var similarity1 = embedding.CosineSimilarity(embeddingB); + var similarity2 = embeddingB.CosineSimilarity(embedding); + + // Assert + Assert.Equal(similarity1, similarity2, precision: 10); + } + + [Fact] + public void FromArray_WithUnitVector_RemainsUnchanged() + { + // Arrange + var vector = CreateTestVector(1.0f, 0.0f, 0.0f); // Already unit vector + + // Act + var embedding = NormalizedEmbedding.FromArray(vector, normalize: true); + + // Assert + Assert.Equal(1.0f, embedding.Vector[0], precision: 5); + Assert.Equal(0.0f, embedding.Vector[1], precision: 5); + Assert.Equal(0.0f, embedding.Vector[2], precision: 5); + } + + [Fact] + public void AsSpan_CanBeUsedInCalculations() + { + // Arrange + var vector1 = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var vector2 = CreateNormalizedVector(1.0f, 2.0f, 3.0f); + var embedding = new NormalizedEmbedding(vector1); + + // Act + var span = embedding.AsSpan(); + var dotProduct = 0.0f; + for (int i = 0; i < span.Length; i++) + { + dotProduct += span[i] * vector2[i]; + } + + // Assert + Assert.Equal(1.0, dotProduct, precision: 5); + } + + [Fact] + public void FromArray_WithEmptyVector_CreatesEmptyEmbedding() + { + // Arrange + var vector = CreateTestVector(); + + // Act + var embedding = NormalizedEmbedding.FromArray(vector, normalize: false); + + // Assert + Assert.Equal(0, embedding.Length); + Assert.NotNull(embedding.Vector); + } +} diff --git a/dotnet/typeagent/tests/typeChat.test/PromptTests.cs b/dotnet/typeagent/tests/typeChat.test/PromptTests.cs new file mode 100644 index 000000000..9d57f12c5 --- /dev/null +++ b/dotnet/typeagent/tests/typeChat.test/PromptTests.cs @@ -0,0 +1,216 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.TypeChat; +using Xunit; + +namespace Microsoft.TypeChat.Tests; + +public class PromptTests +{ + [Fact] + public void Constructor_EmptyPrompt_CreatesEmptyList() + { + var prompt = new Prompt(); + Assert.Empty(prompt); + } + + [Fact] + public void Constructor_PromptSection_AddsSection() + { + var section = new PromptSection("user", "Hello"); + var prompt = new Prompt(section); + Assert.Single(prompt); + Assert.Equal(section, prompt[0]); + } + + [Fact] + public void Constructor_PreambleTextPostamble_AddsAllSections() + { + var preamble = new[] { new PromptSection("system", "Pre") }; + var text = new PromptSection("user", "Main"); + var postamble = new[] { new PromptSection("assistant", "Post") }; + + var prompt = new Prompt(preamble, text, postamble); + + Assert.Equal(3, prompt.Count); + Assert.Equal("Pre", prompt[0].GetText()); + Assert.Equal("Main", prompt[1].GetText()); + Assert.Equal("Post", prompt[2].GetText()); + } + + [Fact] + public void Add_NullSection_ThrowsArgumentNullException() + { + var prompt = new Prompt(); + Assert.Throws(() => prompt.Add(null)); + } + + [Fact] + public void Append_StringSourceSection_AddsSection() + { + var prompt = new Prompt(); + prompt.Append("user", "Hello"); + Assert.Single(prompt); + Assert.Equal("Hello", prompt[0].GetText()); + Assert.Equal("user", prompt[0].Source); + } + + [Fact] + public void Append_StringSection_AddsUserSection() + { + var prompt = new Prompt(); + prompt.Append("Hello"); + Assert.Single(prompt); + Assert.Equal("Hello", prompt[0].GetText()); + Assert.Equal(PromptSection.Sources.User, prompt[0].Source); + } + + [Fact] + public void AppendInstruction_AddsSystemSection() + { + var prompt = new Prompt(); + prompt.AppendInstruction("Do this"); + Assert.Single(prompt); + Assert.Equal("Do this", prompt[0].GetText()); + Assert.Equal(PromptSection.Sources.System, prompt[0].Source); + } + + [Fact] + public void AppendResponse_AddsAssistantSection() + { + var prompt = new Prompt(); + prompt.AppendResponse("Response"); + Assert.Single(prompt); + Assert.Equal("Response", prompt[0].GetText()); + Assert.Equal(PromptSection.Sources.Assistant, prompt[0].Source); + } + + [Fact] + public void Append_PromptSection_AddsSection() + { + var prompt = new Prompt(); + var section = new PromptSection("user", "Hello"); + prompt.Append(section); + Assert.Single(prompt); + Assert.Equal(section, prompt[0]); + } + + [Fact] + public void Append_EnumerableSections_AddsAll() + { + var prompt = new Prompt(); + var sections = new[] + { + new PromptSection("user", "A"), + new PromptSection("system", "B") + }; + prompt.Append(sections); + Assert.Equal(2, prompt.Count); + Assert.Equal("A", prompt[0].GetText()); + Assert.Equal("B", prompt[1].GetText()); + } + + [Fact] + public void Append_Prompt_AddsAllSections() + { + var prompt1 = new Prompt(); + prompt1.Append("Hello"); + var prompt2 = new Prompt(); + prompt2.Append("World"); + prompt1.Append(prompt2); + Assert.Equal(2, prompt1.Count); + Assert.Equal("Hello", prompt1[0].GetText()); + Assert.Equal("World", prompt1[1].GetText()); + } + + [Fact] + public void Last_ReturnsLastSectionOrNull() + { + var prompt = new Prompt(); + Assert.Null(prompt.Last()); + var section = new PromptSection("user", "Hello"); + prompt.Append(section); + Assert.Equal(section, prompt.Last()); + } + + [Fact] + public void JoinSections_ConcatenatesSections() + { + var prompt = new Prompt(); + prompt.Append("A"); + prompt.Append("B"); + var sb = prompt.JoinSections(",", false); + Assert.Equal("A,B,", sb.ToString()); + } + + [Fact] + public void JoinSections_IncludeSource_ConcatenatesWithSource() + { + var prompt = new Prompt(); + prompt.AppendInstruction("Do"); + prompt.Append("Say"); + var sb = prompt.JoinSections("|", true); + Assert.Contains("system:", sb.ToString()); + Assert.Contains("user:", sb.ToString()); + } + + [Fact] + public void ToString_ConcatenatesSections() + { + var prompt = new Prompt(); + prompt.Append("A"); + prompt.Append("B"); + Assert.Equal("A\nB\n", prompt.ToString()); + } + + [Fact] + public void ToString_IncludeSource_ConcatenatesWithSource() + { + var prompt = new Prompt(); + prompt.AppendInstruction("Do"); + prompt.Append("Say"); + var result = prompt.ToString(true); + Assert.Contains("system:", result); + Assert.Contains("user:", result); + } + + [Fact] + public void GetLength_ReturnsTotalLength() + { + var prompt = new Prompt(); + prompt.Append("A"); + prompt.Append("BC"); + Assert.Equal(3, prompt.GetLength()); + } + + [Fact] + public void ImplicitOperator_PromptFromString() + { + Prompt prompt = "Hello"; + Assert.Single(prompt); + Assert.Equal("Hello", prompt[0].GetText()); + } + + [Fact] + public void ImplicitOperator_StringFromPrompt() + { + var prompt = new Prompt(); + prompt.Append("Hello"); + string result = prompt; + Assert.Equal("Hello\n", result); + } + + [Fact] + public void OperatorPlus_AppendsSection() + { + var prompt = new Prompt(); + var section = new PromptSection("user", "Hello"); + var result = prompt + section; + Assert.Single(result); + Assert.Equal(section, result[0]); + } +} diff --git a/dotnet/typeagent/tests/typeChat.test/SchemaTests.cs b/dotnet/typeagent/tests/typeChat.test/SchemaTests.cs new file mode 100644 index 000000000..17beaf48c --- /dev/null +++ b/dotnet/typeagent/tests/typeChat.test/SchemaTests.cs @@ -0,0 +1,635 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.TypeChat.Tests; + +public class SchemaTests +{ + #region Enum Tests + + [Fact] + public void Coffees_Enum_HasExpectedValues() + { + Assert.Equal(4, Enum.GetValues().Length); + Assert.True(Enum.IsDefined(typeof(Coffees), Coffees.Coffee)); + Assert.True(Enum.IsDefined(typeof(Coffees), Coffees.Latte)); + Assert.True(Enum.IsDefined(typeof(Coffees), Coffees.Mocha)); + Assert.True(Enum.IsDefined(typeof(Coffees), Coffees.Unknown)); + } + + [Fact] + public void CoffeeSize_Enum_HasExpectedValues() + { + Assert.Equal(5, Enum.GetValues().Length); + Assert.True(Enum.IsDefined(typeof(CoffeeSize), CoffeeSize.Small)); + Assert.True(Enum.IsDefined(typeof(CoffeeSize), CoffeeSize.Medium)); + Assert.True(Enum.IsDefined(typeof(CoffeeSize), CoffeeSize.Large)); + Assert.True(Enum.IsDefined(typeof(CoffeeSize), CoffeeSize.Grande)); + Assert.True(Enum.IsDefined(typeof(CoffeeSize), CoffeeSize.Venti)); + } + + #endregion + + #region CoffeeOrder Tests + + [Fact] + public void CoffeeOrder_CanBeCreated() + { + var order = new CoffeeOrder + { + Coffee = Coffees.Latte, + Quantity = 2, + Size = CoffeeSize.Grande + }; + + Assert.Equal(Coffees.Latte, order.Coffee); + Assert.Equal(2, order.Quantity); + Assert.Equal(CoffeeSize.Grande, order.Size); + } + + [Fact] + public void CoffeeOrder_Serializes_WithCorrectPropertyNames() + { + var order = new CoffeeOrder + { + Coffee = Coffees.Mocha, + Quantity = 1, + Size = CoffeeSize.Large + }; + + var json = System.Text.Json.JsonSerializer.Serialize(order); + Assert.Contains("\"coffee\"", json); + Assert.Contains("\"quantity\"", json); + Assert.Contains("\"size\"", json); + } + + #endregion + + #region Creamer and Milk Tests + + [Fact] + public void Creamer_HasNameProperty() + { + var creamer = new Creamer { Name = "half and half" }; + Assert.Equal("half and half", creamer.Name); + } + + [Fact] + public void Milk_SerializesWithCorrectPropertyName() + { + var milk = new Milk { Name = "whole milk" }; + var json = System.Text.Json.JsonSerializer.Serialize(milk); + Assert.Contains("\"name\"", json); + } + + #endregion + + #region DessertOrder Tests + + [Fact] + public void DessertOrder_DefaultConstructor_SetsDefaultQuantity() + { + var dessert = new DessertOrder { Name = "Tiramisu" }; + Assert.Equal(1, dessert.Quantity); + } + + [Fact] + public void DessertOrder_ParameterizedConstructor_SetsValues() + { + var dessert = new DessertOrder("Chocolate Cake", 3); + Assert.Equal("Chocolate Cake", dessert.Name); + Assert.Equal(3, dessert.Quantity); + } + + [Fact] + public void DessertOrder_ImplicitConversion_FromString() + { + DessertOrder dessert = "Tiramisu"; + Assert.Equal("Tiramisu", dessert.Name); + Assert.Equal(1, dessert.Quantity); + } + + [Fact] + public void DessertOrder_Serializes_WithCorrectPropertyNames() + { + var dessert = new DessertOrder("Coffee Cake", 2); + var json = System.Text.Json.JsonSerializer.Serialize(dessert); + Assert.Contains("\"dessert\"", json); + Assert.Contains("\"quantity\"", json); + } + + #endregion + + #region FruitOrder Tests + + [Fact] + public void FruitOrder_CanBeCreated() + { + var fruit = new FruitOrder { Name = "Banana", Quantity = 5 }; + Assert.Equal("Banana", fruit.Name); + Assert.Equal(5, fruit.Quantity); + } + + [Fact] + public void FruitOrder_Serializes_WithCorrectPropertyNames() + { + var fruit = new FruitOrder { Name = "Regular Apple", Quantity = 3 }; + var json = System.Text.Json.JsonSerializer.Serialize(fruit); + Assert.Contains("\"fruit\"", json); + Assert.Contains("\"quantity\"", json); + } + + #endregion + + #region UnknownItem Tests + + [Fact] + public void UnknownItem_StoresText() + { + var unknown = new UnknownItem { Text = "something unclear" }; + Assert.Equal("something unclear", unknown.Text); + } + + #endregion + + #region Order Tests + + [Fact] + public void Order_CanContainMultipleOrderTypes() + { + var order = new Order + { + Coffees = new[] { new CoffeeOrder { Coffee = Coffees.Latte, Quantity = 1, Size = CoffeeSize.Medium } }, + Desserts = new[] { new DessertOrder("Tiramisu", 2) }, + Fruits = new[] { new FruitOrder { Name = "Banana", Quantity = 3 } }, + Unknown = new[] { new UnknownItem { Text = "unclear item" } } + }; + + Assert.NotNull(order.Coffees); + Assert.Single(order.Coffees); + Assert.NotNull(order.Desserts); + Assert.Single(order.Desserts); + Assert.NotNull(order.Fruits); + Assert.Single(order.Fruits); + Assert.NotNull(order.Unknown); + Assert.Single(order.Unknown); + } + + [Fact] + public void Order_AllowsNullCollections() + { + var order = new Order { Desserts = Array.Empty() }; + Assert.Null(order.Coffees); + Assert.Null(order.Fruits); + Assert.Null(order.Unknown); + } + + #endregion + + #region SentimentResponse Tests + + [Fact] + public void SentimentResponse_CanStoreSentiment() + { + var response = new SentimentResponse { Sentiment = "positive" }; + Assert.Equal("positive", response.Sentiment); + } + + #endregion + + #region NullableTestObj Tests + + [Fact] + public void NullableTestObj_SupportsNullableAndRequiredFields() + { + var obj = new NullableTestObj + { + Required = CoffeeSize.Large, + Optional = null, + Text = "test", + OptionalText = null, + OptionalTextField = "field", + Amt = 100, + OptionalAmt = null + }; + + Assert.Equal(CoffeeSize.Large, obj.Required); + Assert.Null(obj.Optional); + Assert.Equal("test", obj.Text); + Assert.Null(obj.OptionalText); + Assert.Equal("field", obj.OptionalTextField); + Assert.Equal(100, obj.Amt); + Assert.Null(obj.OptionalAmt); + } + + [Fact] + public void NullableTestObj_CanSetOptionalValues() + { + var obj = new NullableTestObj + { + Required = CoffeeSize.Small, + Optional = CoffeeSize.Medium, + Text = "required", + OptionalText = "optional", + Amt = 50, + OptionalAmt = 25 + }; + + Assert.Equal(CoffeeSize.Medium, obj.Optional); + Assert.Equal("optional", obj.OptionalText); + Assert.Equal(25, obj.OptionalAmt); + } + + #endregion + + #region WrapperNullableObj Tests + + [Fact] + public void WrapperNullableObj_CanWrapNullableTest() + { + var wrapper = new WrapperNullableObj + { + Test = new NullableTestObj { Required = CoffeeSize.Grande, Amt = 10, Text = "test" }, + OptionalMilk = "whole milk" + }; + + Assert.NotNull(wrapper.Test); + Assert.Equal("whole milk", wrapper.OptionalMilk); + } + + [Fact] + public void WrapperNullableObj_AllowsNullValues() + { + var wrapper = new WrapperNullableObj { Test = null, OptionalMilk = null }; + Assert.Null(wrapper.Test); + Assert.Null(wrapper.OptionalMilk); + } + + #endregion + + #region ConverterTestObj and HardcodedVocabObj Tests + + [Fact] + public void ConverterTestObj_HasMilkProperty() + { + var obj = new ConverterTestObj { Milk = "Almond" }; + Assert.Equal("Almond", obj.Milk); + } + + [Fact] + public void HardcodedVocabObj_HasVocabName() + { + Assert.Equal("Local", HardcodedVocabObj.VocabName); + } + + [Fact] + public void HardcodedVocabObj_CanSetValue() + { + var obj = new HardcodedVocabObj { Value = "Two" }; + Assert.Equal("Two", obj.Value); + } + + #endregion + + #region JsonFunc and JsonExpr Tests + + [Fact] + public void JsonFunc_HasNameProperty() + { + var func = new JsonFunc { Name = "testFunc" }; + Assert.Equal("testFunc", func.Name); + } + + [Fact] + public void JsonExpr_CanContainFuncAndValue() + { + var func = new JsonFunc { Name = "add" }; + var value = System.Text.Json.JsonDocument.Parse("42").RootElement; + var expr = new JsonExpr { Func = func, Value = value }; + + Assert.NotNull(expr.Func); + Assert.Equal("add", expr.Func.Name); + Assert.Equal(System.Text.Json.JsonValueKind.Number, expr.Value.ValueKind); + } + + #endregion + + #region TestVocabs Tests + + [Fact] + public void TestVocabs_Names_HasExpectedConstants() + { + Assert.Equal("Desserts", TestVocabs.Names.Desserts); + Assert.Equal("Fruits", TestVocabs.Names.Fruits); + Assert.Equal("Milks", TestVocabs.Names.Milks); + Assert.Equal("Creamers", TestVocabs.Names.Creamers); + } + + [Fact] + public void TestVocabs_Desserts_ReturnsVocab() + { + var vocab = TestVocabs.Desserts(); + Assert.Equal(TestVocabs.Names.Desserts, vocab.Name); + Assert.NotNull(vocab.Vocab); + Assert.Contains("Tiramisu", vocab.Vocab); + Assert.Contains("Chocolate Cake", vocab.Vocab); + } + + [Fact] + public void TestVocabs_Fruits_ReturnsVocab() + { + var vocab = TestVocabs.Fruits(); + Assert.Equal(TestVocabs.Names.Fruits, vocab.Name); + Assert.Contains("Banana", vocab.Vocab); + Assert.Contains("Regular Apple", vocab.Vocab); + } + + [Fact] + public void TestVocabs_Milks_ReturnsVocab() + { + var vocab = TestVocabs.Milks(); + Assert.Equal(TestVocabs.Names.Milks, vocab.Name); + Assert.Contains("whole milk", vocab.Vocab); + Assert.Contains("almond milk", vocab.Vocab); + } + + [Fact] + public void TestVocabs_Creamers_ReturnsVocab() + { + var vocab = TestVocabs.Creamers(); + Assert.Equal(TestVocabs.Names.Creamers, vocab.Name); + Assert.Contains("half and half", vocab.Vocab); + Assert.Contains("heavy cream", vocab.Vocab); + } + + [Fact] + public void TestVocabs_All_ReturnsAllVocabs() + { + var allVocabs = TestVocabs.All(); + Assert.Equal(4, allVocabs.Count); + } + + #endregion + + #region Person, Name, and Location Tests + + [Fact] + public void Person_CanBeCreated() + { + var person = new Person + { + Name = new Name { FirstName = "John", LastName = "Doe" }, + Age = 30, + Location = new Location { City = "Seattle", State = "WA", Country = "USA" } + }; + + Assert.Equal("John", person.Name.FirstName); + Assert.Equal(30, person.Age); + Assert.Equal("Seattle", person.Location.City); + } + + [Fact] + public void Person_HasSameName_ComparesCorrectly() + { + var person1 = new Person { Name = new Name { FirstName = "Jane", LastName = "Smith" }, Age = 25 }; + var person2 = new Person { Name = new Name { FirstName = "Jane", LastName = "Smith" }, Age = 30 }; + var person3 = new Person { Name = new Name { FirstName = "John", LastName = "Smith" }, Age = 25 }; + + Assert.True(person1.HasSameName(person2)); + Assert.False(person1.HasSameName(person3)); + } + + [Fact] + public void Person_ChangeCase_ModifiesNameAndLocation() + { + var person = new Person + { + Name = new Name { FirstName = "John", LastName = "Doe" }, + Age = 30, + Location = new Location { City = "Seattle", State = "WA", Country = "USA" } + }; + + person.ChangeCase(true); + Assert.Equal("JOHN", person.Name.FirstName); + Assert.Equal("DOE", person.Name.LastName); + Assert.Equal("SEATTLE", person.Location.City); + + person.ChangeCase(false); + Assert.Equal("john", person.Name.FirstName); + Assert.Equal("seattle", person.Location.City); + } + + [Fact] + public void Name_CompareTo_WorksCorrectly() + { + var name1 = new Name { FirstName = "Alice", LastName = "Brown" }; + var name2 = new Name { FirstName = "Bob", LastName = "Brown" }; + var name3 = new Name { FirstName = "Alice", LastName = "Brown" }; + + Assert.True(name1.CompareTo(name2) < 0); + Assert.Equal(0, name1.CompareTo(name3)); + Assert.True(name2.CompareTo(name1) > 0); + } + + [Fact] + public void Name_ToString_FormatsCorrectly() + { + var name = new Name { FirstName = "Jane", LastName = "Doe" }; + Assert.Equal("Jane Doe", name.ToString()); + } + + [Fact] + public void Location_ChangeCase_ModifiesAllFields() + { + var location = new Location { City = "Portland", State = "OR", Country = "USA" }; + + location.ChangeCase(true); + Assert.Equal("PORTLAND", location.City); + Assert.Equal("OR", location.State); + Assert.Equal("USA", location.Country); + } + + #endregion + + #region AuthorPerson and FriendsOfPerson Tests + + [Fact] + public void AuthorPerson_CanStoreBooks() + { + var author = new AuthorPerson + { + Name = new Name { FirstName = "Isaac", LastName = "Asimov" }, + Books = new[] { "Foundation", "I, Robot" } + }; + + Assert.Equal(2, author.Books.Length); + Assert.Contains("Foundation", author.Books); + } + + [Fact] + public void FriendsOfPerson_CanStoreFriendNames() + { + var person = new FriendsOfPerson + { + Name = new Name { FirstName = "Alice", LastName = "Smith" }, + FriendNames = new[] + { + new Name { FirstName = "Bob", LastName = "Jones" }, + new Name { FirstName = "Carol", LastName = "White" } + } + }; + + Assert.Equal(2, person.FriendNames.Length); + Assert.Equal("Bob", person.FriendNames[0].FirstName); + } + + #endregion + + #region Generic Tests + + [Fact] + public void Child_Generic_CanStoreValue() + { + var child = new Child { Name = "IntChild", Value = 42 }; + Assert.Equal("IntChild", child.Name); + Assert.Equal(42, child.Value); + + var childString = new Child { Name = "StringChild", Value = "test" }; + Assert.Equal("test", childString.Value); + } + + [Fact] + public void Parent_Generic_CanStoreChildrenOfDifferentTypes() + { + var parent = new Parent + { + ChildrenX = new[] { new Child { Name = "Child1", Value = 1 } }, + ChildrenY = new[] { new Child { Name = "Child2", Value = "two" } } + }; + + Assert.Single(parent.ChildrenX); + Assert.Single(parent.ChildrenY); + Assert.Equal(1, parent.ChildrenX[0].Value); + Assert.Equal("two", parent.ChildrenY[0].Value); + } + + #endregion + + #region Polymorphic Shape Tests (NET7_0_OR_GREATER) + +#if NET7_0_OR_GREATER + [Fact] + public void Rectangle_CanBeCreated() + { + var rect = new Rectangle + { + Id = "rect1", + TopX = 10, + TopY = 20, + Height = 100, + Width = 50 + }; + + Assert.Equal("rect1", rect.Id); + Assert.Equal(10, rect.TopX); + Assert.Equal(100, rect.Height); + } + + [Fact] + public void Circle_CanBeCreated() + { + var circle = new Circle + { + Id = "circle1", + CenterX = 50, + CenterY = 50, + Radius = 25 + }; + + Assert.Equal("circle1", circle.Id); + Assert.Equal(25, circle.Radius); + } + + [Fact] + public void Drawing_CanStoreMultipleShapes() + { + var drawing = new Drawing + { + Shapes = new Shape[] + { + new Rectangle { Id = "r1", TopX = 0, TopY = 0, Height = 10, Width = 10 }, + new Circle { Id = "c1", CenterX = 5, CenterY = 5, Radius = 2 } + } + }; + + Assert.Equal(2, drawing.Shapes.Length); + } + + [Fact] + public void Drawing_GetShape_ReturnsCorrectShape() + { + var drawing = new Drawing + { + Shapes = + [ + new Rectangle { Id = "r1", TopX = 0, TopY = 0, Height = 10, Width = 10 }, + new Circle { Id = "c1", CenterX = 5, CenterY = 5, Radius = 2 }, + new Rectangle { Id = "r2", TopX = 20, TopY = 20, Height = 5, Width = 5 } + ] + }; + + var firstRect = drawing.GetShape(0); + Assert.NotNull(firstRect); + Assert.Equal("r1", firstRect.Id); + + var secondRect = drawing.GetShape(1); + Assert.NotNull(secondRect); + Assert.Equal("r2", secondRect.Id); + + var circle = drawing.GetShape(0); + Assert.NotNull(circle); + Assert.Equal("c1", circle.Id); + } + + [Fact] + public void Drawing_GetShape_ReturnsNull_WhenNotFound() + { + var drawing = new Drawing + { + Shapes = new Shape[] { new Rectangle { Id = "r1", TopX = 0, TopY = 0, Height = 10, Width = 10 } } + }; + + var circle = drawing.GetShape(0); + Assert.Null(circle); + } + + [Fact] + public void Drawing_Serialization_PreservesPolymorphicTypes() + { + var drawing = new Drawing + { + Shapes = new Shape[] + { + new Rectangle { Id = "r1", TopX = 0, TopY = 0, Height = 10, Width = 10 }, + new Circle { Id = "c1", CenterX = 5, CenterY = 5, Radius = 2 } + } + }; + + var json = System.Text.Json.JsonSerializer.Serialize(drawing); + var deserialized = System.Text.Json.JsonSerializer.Deserialize(json); + + Assert.NotNull(deserialized); + Assert.Equal(2, deserialized.Shapes.Length); + Assert.IsType(deserialized.Shapes[0]); + Assert.IsType(deserialized.Shapes[1]); + } +#endif + + #endregion +} diff --git a/dotnet/typeagent/tests/typeChat.test/Schemas.cs b/dotnet/typeagent/tests/typeChat.test/Schemas.cs index 7f0d12262..cb8054941 100644 --- a/dotnet/typeagent/tests/typeChat.test/Schemas.cs +++ b/dotnet/typeagent/tests/typeChat.test/Schemas.cs @@ -392,6 +392,8 @@ public class Drawing { return (T)shape; } + + curNumber++; } } return null; diff --git a/dotnet/typeagent/tests/typeChat.test/StringExTests.cs b/dotnet/typeagent/tests/typeChat.test/StringExTests.cs new file mode 100644 index 000000000..a84dae621 --- /dev/null +++ b/dotnet/typeagent/tests/typeChat.test/StringExTests.cs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.TypeChat.Tests; + +public class StringExTests +{ + [Fact] + public void ExtractLine_ReturnsCorrectLines() + { + // Arrange + string text = "Line1\nLine2\nLine3\nLine4\nLine5"; + var sb = new StringBuilder(); + + // Act + text.ExtractLine(2, sb); + + // Assert + var expected = "Line2\r\nLine3\r\nLine4\r\n"; + Assert.Equal(expected, sb.ToString()); + } + + [Fact] + public void ExtractLine_HandlesFirstLine() + { + string text = "A\nB\nC"; + var sb = new StringBuilder(); + + text.ExtractLine(0, sb); + + var expected = "A\r\nB\r\n"; + Assert.Equal(expected, sb.ToString()); + } + + [Fact] + public void AppendLineNotEmpty_AppendsNonEmptyLine() + { + var sb = new StringBuilder(); + sb.AppendLineNotEmpty("Test"); + + Assert.Equal("Test\r\n", sb.ToString()); + } + + [Fact] + public void AppendLineNotEmpty_DoesNotAppendEmptyLine() + { + var sb = new StringBuilder(); + sb.AppendLineNotEmpty(""); + sb.AppendLineNotEmpty(null); + + Assert.Equal(string.Empty, sb.ToString()); + } + + [Fact] + public void TrimAndAppendLine_TrimsAndAppends() + { + var sb = new StringBuilder(); + sb.TrimAndAppendLine(" Hello World "); + + Assert.Equal("Hello World\r\n", sb.ToString()); + } + + [Fact] + public void AppendMultiple_AppendsWithSeparator() + { + var sb = new StringBuilder(); + var items = new List { "A", "B", "C" }; + + sb.AppendMultiple(",", items); + + Assert.Equal("A,B,C", sb.ToString()); + } + + [Fact] + public void AppendMultiple_EmptyList() + { + var sb = new StringBuilder(); + var items = new List(); + + sb.AppendMultiple(",", items); + + Assert.Equal(string.Empty, sb.ToString()); + } +} diff --git a/dotnet/typeagent/tests/typeChat.test/TestConfig.cs b/dotnet/typeagent/tests/typeChat.test/TestConfig.cs index 7744ebd82..54a1c3657 100644 --- a/dotnet/typeagent/tests/typeChat.test/TestConfig.cs +++ b/dotnet/typeagent/tests/typeChat.test/TestConfig.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Text.Json; + namespace Microsoft.TypeChat.Tests; public class TestConfig : TypeChatTest @@ -98,4 +100,32 @@ public void TestEnvAzure() SetEnv(OpenAIConfig.VariableNames.AZURE_OPENAI_API_KEY, prevKey); } } + + [Fact] + public void TestConfigFromFile() + { + // create a temp config file + string file = Path.GetTempFileName(); + var config = OpenAIConfig.FromEnvironment(); + var sconfig = JsonSerializer.Serialize(config); + File.WriteAllText(file, sconfig); + var fileCfg = OpenAIConfig.LoadFromJsonFile(file); + + Assert.Equal(sconfig, JsonSerializer.Serialize(fileCfg)); + } + + [Fact] + public void TestInvalidConfigFile() + { + string file = Path.GetTempFileName(); + File.AppendAllText(file, ""); + + Assert.Throws(() => + { + var config = OpenAIConfig.LoadFromJsonFile(file); + config.Validate(file); + }); + + File.Delete(file); + } } diff --git a/dotnet/typeagent/tests/typeChat.test/TextEmbeddingCachetests.cs b/dotnet/typeagent/tests/typeChat.test/TextEmbeddingCachetests.cs new file mode 100644 index 000000000..1d6d38b16 --- /dev/null +++ b/dotnet/typeagent/tests/typeChat.test/TextEmbeddingCachetests.cs @@ -0,0 +1,298 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using TypeAgent.Vector; + +namespace Microsoft.TypeChat.Tests; + +public class TextEmbeddingCacheTests +{ + private const int DefaultCacheSize = 10; + + private static float[] CreateTestEmbedding(int seed = 0) + { + var random = new Random(seed); + var embedding = new float[128]; + for (int i = 0; i < embedding.Length; i++) + { + embedding[i] = (float)random.NextDouble(); + } + return embedding; + } + + [Fact] + public void Constructor_SetsMemCacheSize() + { + // Arrange & Act + var cache = new TextEmbeddingCache(DefaultCacheSize); + + // Assert + Assert.NotNull(cache); + Assert.Null(cache.PersistentCache); + } + + [Fact] + public void Add_WithValidEmbedding_AddsToCache() + { + // Arrange + var cache = new TextEmbeddingCache(DefaultCacheSize); + var key = "test_key"; + var embedding = CreateTestEmbedding(1); + + // Act + cache.Add(key, embedding); + + // Assert + var result = cache.Get(key); + Assert.NotNull(result); + Assert.Equal(embedding, result); + } + + [Fact] + public void Add_WithNullValue_DoesNotAddToCache() + { + // Arrange + var cache = new TextEmbeddingCache(DefaultCacheSize); + var key = "test_key"; + + // Act + cache.Add(key, null); + + // Assert + Assert.Null(cache.Get(key)); + } + + [Fact] + public void Get_WithExistingKey_ReturnsEmbedding() + { + // Arrange + var cache = new TextEmbeddingCache(DefaultCacheSize); + var key = "test_key"; + var embedding = CreateTestEmbedding(2); + cache.Add(key, embedding); + + // Act + var result = cache.Get(key); + + // Assert + Assert.NotNull(result); + Assert.Equal(embedding, result); + } + + [Fact] + public void Get_WithNonExistingKey_ReturnsNull() + { + // Arrange + var cache = new TextEmbeddingCache(DefaultCacheSize); + + // Act + var result = cache.Get("nonexistent_key"); + + // Assert + Assert.Null(result); + } + + [Fact] + public void TryGet_WithExistingKey_ReturnsTrueAndValue() + { + // Arrange + var cache = new TextEmbeddingCache(DefaultCacheSize); + var key = "test_key"; + var embedding = CreateTestEmbedding(3); + cache.Add(key, embedding); + + // Act + var result = cache.TryGet(key, out var value); + + // Assert + Assert.True(result); + Assert.NotNull(value); + Assert.Equal(embedding, value); + } + + [Fact] + public void TryGet_WithNonExistingKey_ReturnsFalse() + { + // Arrange + var cache = new TextEmbeddingCache(DefaultCacheSize); + + // Act + var result = cache.TryGet("nonexistent_key", out var value); + + // Assert + Assert.False(result); + Assert.Null(value); + } + + [Fact] + public void TryGet_WithPersistentCache_FallsBackToPersistentCache() + { + // Arrange + var cache = new TextEmbeddingCache(DefaultCacheSize); + var key = "test_key"; + var embedding = CreateTestEmbedding(4); + var persistentCache = new MockPersistentCache(); + persistentCache.Add(key, new Embedding(embedding)); + cache.PersistentCache = persistentCache; + + // Act + var result = cache.TryGet(key, out var value); + + // Assert + Assert.True(result); + Assert.NotNull(value); + Assert.Equal(embedding, value); + } + + [Fact] + public void TryGet_ChecksMemCacheBeforePersistentCache() + { + // Arrange + var cache = new TextEmbeddingCache(DefaultCacheSize); + var key = "test_key"; + var memEmbedding = CreateTestEmbedding(5); + var persistentEmbedding = CreateTestEmbedding(6); + + var persistentCache = new MockPersistentCache(); + persistentCache.Add(key, new Embedding(persistentEmbedding)); + cache.PersistentCache = persistentCache; + + cache.Add(key, memEmbedding); + + // Act + var result = cache.TryGet(key, out var value); + + // Assert + Assert.True(result); + Assert.NotNull(value); + Assert.Equal(memEmbedding, value); // Should get memory cache value + } + + [Fact] + public void Add_MultipleKeys_AllStored() + { + // Arrange + var cache = new TextEmbeddingCache(3); + var keys = new[] { "key1", "key2", "key3" }; + var embeddings = keys.Select((_, i) => CreateTestEmbedding(i + 10)).ToArray(); + + // Act + for (int i = 0; i < keys.Length; i++) + { + cache.Add(keys[i], embeddings[i]); + } + + // Assert + Assert.Equal(keys.Length, cache.Count); + for (int i = 0; i < keys.Length; i++) + { + var result = cache.Get(keys[i]); + Assert.NotNull(result); + Assert.Equal(embeddings[i], result); + } + } + + [Fact] + public void Add_ExceedingCacheSize_EvictsOldEntries() + { + // Arrange + var cacheSize = 3; + var cache = new TextEmbeddingCache(cacheSize); + var keys = new[] { "key1", "key2", "key3", "key4", "key5" }; + + // Act + for (int i = 0; i < keys.Length; i++) + { + cache.Add(keys[i], CreateTestEmbedding(i + 20)); + } + + // Assert + // Cache should have evicted old entries and count should reflect the LRU high watermark + Assert.True(cache.Count <= cacheSize); + + // Most recent entries should be accessible + Assert.NotNull(cache.Get(keys[^1])); // Last key should be present + } + + [Fact] + public void PersistentCache_CanBeSetAndRetrieved() + { + // Arrange + var cache = new TextEmbeddingCache(DefaultCacheSize); + var persistentCache = new MockPersistentCache(); + + // Act + cache.PersistentCache = persistentCache; + + // Assert + Assert.NotNull(cache.PersistentCache); + Assert.Same(persistentCache, cache.PersistentCache); + } + + [Fact] + public void Count_ReflectsHighWatermark() + { + // Arrange + var cache = new TextEmbeddingCache(3); + + // Act + cache.Add("key1", CreateTestEmbedding(30)); + cache.Add("key2", CreateTestEmbedding(31)); + cache.Add("key3", CreateTestEmbedding(32)); + cache.Add("key4", CreateTestEmbedding(33)); + + // Assert + Assert.Equal(3, cache.Count); + } + + [Fact] + public void TryGet_WithPersistentCacheReturningNull_ReturnsFalse() + { + // Arrange + var cache = new TextEmbeddingCache(DefaultCacheSize); + var persistentCache = new MockPersistentCache(); + cache.PersistentCache = persistentCache; + + // Act + var result = cache.TryGet("nonexistent", out var value); + + // Assert + Assert.False(result); + Assert.Null(value); + } + + [Fact] + public void Add_SameKeyTwice_UpdatesValue() + { + // Arrange + var cache = new TextEmbeddingCache(DefaultCacheSize); + var key = "test_key"; + var embedding1 = CreateTestEmbedding(40); + var embedding2 = CreateTestEmbedding(41); + + // Act + cache.Add(key, embedding1); + cache.Add(key, embedding2); + + // Assert + var result = cache.Get(key); + Assert.NotNull(result); + Assert.Equal(embedding2, result); + } + + // Mock implementation of IReadOnlyCache for testing + private class MockPersistentCache : IReadOnlyCache + { + private readonly Dictionary _storage = new(); + + public void Add(string key, Embedding value) + { + _storage[key] = value; + } + + public bool TryGet(string key, out Embedding value) + { + return _storage.TryGetValue(key, out value); + } + } +} diff --git a/dotnet/typeagent/tests/typeChat.test/TextEmbeddingModelWithCacheTests.cs b/dotnet/typeagent/tests/typeChat.test/TextEmbeddingModelWithCacheTests.cs new file mode 100644 index 000000000..e264dd555 --- /dev/null +++ b/dotnet/typeagent/tests/typeChat.test/TextEmbeddingModelWithCacheTests.cs @@ -0,0 +1,452 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using TypeAgent.AIClient; +using TypeAgent.Vector; + +namespace Microsoft.TypeChat.Tests; + +public class TextEmbeddingModelWithCacheTests +{ + private const int DefaultCacheSize = 10; + private const int DefaultMaxBatchSize = 16; + + public TextEmbeddingModelWithCacheTests() + { + TestHelpers.LoadDotEnvOrSkipTest(); + } + + private static float[] CreateTestEmbedding(int seed = 0) + { + var random = new Random(seed); + var embedding = new float[128]; + for (int i = 0; i < embedding.Length; i++) + { + embedding[i] = (float)random.NextDouble(); + } + return embedding; + } + + private class MockTextEmbeddingModel : ITextEmbeddingModel + { + private readonly Dictionary _embeddings = []; + private int _generateCallCount = 0; + private int _generateBatchCallCount = 0; + + public int MaxBatchSize { get; } + + public int GenerateCallCount => _generateCallCount; + public int GenerateBatchCallCount => _generateBatchCallCount; + + public MockTextEmbeddingModel(int maxBatchSize = DefaultMaxBatchSize) + { + MaxBatchSize = maxBatchSize; + } + + public Task GenerateAsync(string text, CancellationToken cancellationToken) + { + _generateCallCount++; + + if (!_embeddings.TryGetValue(text, out var embedding)) + { + embedding = CreateTestEmbedding(text.GetHashCode()); + _embeddings[text] = embedding; + } + + return Task.FromResult(embedding); + } + + public Task> GenerateAsync(IList texts, CancellationToken cancellationToken) + { + _generateBatchCallCount++; + + var results = new List(); + foreach (var text in texts) + { + if (!_embeddings.TryGetValue(text, out var embedding)) + { + embedding = CreateTestEmbedding(text.GetHashCode()); + _embeddings[text] = embedding; + } + results.Add(embedding); + } + + return Task.FromResult>(results); + } + + public void Reset() + { + _generateCallCount = 0; + _generateBatchCallCount = 0; + } + } + + [Fact] + public void Constructor_WithCacheSize_CreatesInstance() + { + // Arrange & Act + var model = new TextEmbeddingModelWithCache(DefaultCacheSize); + + // Assert + Assert.NotNull(model); + Assert.NotNull(model.InnerModel); + Assert.NotNull(model.Cache); + Assert.True(model.CacheEnabled); + } + + [Fact] + public void Constructor_WithModelAndCacheSize_CreatesInstance() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + + // Act + var model = new TextEmbeddingModelWithCache(innerModel, DefaultCacheSize); + + // Assert + Assert.NotNull(model); + Assert.NotNull(model.InnerModel); + Assert.NotNull(model.Cache); + Assert.True(model.CacheEnabled); + } + + [Fact] + public void Constructor_WithModelAndCache_CreatesInstance() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + + // Act + var model = new TextEmbeddingModelWithCache(innerModel, cache); + + // Assert + Assert.NotNull(model); + Assert.Same(innerModel, model.InnerModel); + Assert.Same(cache, model.Cache); + Assert.True(model.CacheEnabled); + } + + [Fact] + public void Constructor_WithNullInnerModel_ThrowsArgumentNullException() + { + // Arrange + var cache = new TextEmbeddingCache(DefaultCacheSize); + + // Act & Assert + Assert.Throws(() => new TextEmbeddingModelWithCache(null!, cache)); + } + + [Fact] + public void Constructor_WithNullCache_ThrowsArgumentNullException() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + + // Act & Assert + Assert.Throws(() => new TextEmbeddingModelWithCache(innerModel, null!)); + } + + [Fact] + public void MaxBatchSize_ReturnsInnerModelMaxBatchSize() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(32); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + + // Act + var maxBatchSize = model.MaxBatchSize; + + // Assert + Assert.Equal(32, maxBatchSize); + } + + [Fact] + public void CacheEnabled_CanBeSetAndRetrieved() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + + // Act + model.CacheEnabled = false; + + // Assert + Assert.False(model.CacheEnabled); + } + + [Fact] + public async Task GenerateAsync_SingleText_CallsInnerModelAndCachesResult() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + var text = "test text"; + + // Act + var result = await model.GenerateAsync(text, CancellationToken.None); + + // Assert + Assert.NotNull(result); + Assert.Equal(1, innerModel.GenerateCallCount); + + // Verify result is cached + var cachedResult = cache.Get(text); + Assert.NotNull(cachedResult); + Assert.Equal(result, cachedResult); + } + + [Fact] + public async Task GenerateAsync_SingleText_UsesCacheOnSecondCall() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + var text = "test text"; + + // Act + var result1 = await model.GenerateAsync(text, CancellationToken.None); + innerModel.Reset(); + var result2 = await model.GenerateAsync(text, CancellationToken.None); + + // Assert + Assert.Equal(result1, result2); + Assert.Equal(0, innerModel.GenerateCallCount); + } + + [Fact] + public async Task GenerateAsync_SingleText_WithCacheDisabled_BypassesCache() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + var text = "test text"; + model.CacheEnabled = false; + + // Act + var result = await model.GenerateAsync(text, CancellationToken.None); + + // Assert + Assert.NotNull(result); + Assert.Equal(1, innerModel.GenerateCallCount); + + // Verify result is NOT cached + var cachedResult = cache.Get(text); + Assert.Null(cachedResult); + } + + [Fact] + public async Task GenerateAsync_MultipleTexts_CallsInnerModelAndCachesResults() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + var texts = new List { "text1", "text2", "text3" }; + + // Act + var results = await model.GenerateAsync(texts, CancellationToken.None); + + // Assert + Assert.NotNull(results); + Assert.Equal(3, results.Count); + Assert.Equal(1, innerModel.GenerateBatchCallCount); + + // Verify all results are cached + foreach (var text in texts) + { + var cachedResult = cache.Get(text); + Assert.NotNull(cachedResult); + } + } + + [Fact] + public async Task GenerateAsync_MultipleTexts_UsesCacheForCachedTexts() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + var texts1 = new List { "text1", "text2" }; + var texts2 = new List { "text2", "text3" }; // text2 is already cached + + // Act + await model.GenerateAsync(texts1, CancellationToken.None); + innerModel.Reset(); + var results = await model.GenerateAsync(texts2, CancellationToken.None); + + // Assert + Assert.NotNull(results); + Assert.Equal(2, results.Count); + + // Only text3 should cause a call to inner model + Assert.Equal(1, innerModel.GenerateBatchCallCount); + } + + [Fact] + public async Task GenerateAsync_MultipleTexts_WithCacheDisabled_BypassesCache() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + var texts = new List { "text1", "text2", "text3" }; + model.CacheEnabled = false; + + // Act + var results = await model.GenerateAsync(texts, CancellationToken.None); + + // Assert + Assert.NotNull(results); + Assert.Equal(3, results.Count); + Assert.Equal(1, innerModel.GenerateBatchCallCount); + + // Verify results are NOT cached + foreach (var text in texts) + { + var cachedResult = cache.Get(text); + Assert.Null(cachedResult); + } + } + + [Fact] + public async Task GenerateAsync_MultipleTexts_AllCached_DoesNotCallInnerModel() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + var texts = new List { "text1", "text2", "text3" }; + + // Pre-populate cache + await model.GenerateAsync(texts, CancellationToken.None); + innerModel.Reset(); + + // Act + var results = await model.GenerateAsync(texts, CancellationToken.None); + + // Assert + Assert.NotNull(results); + Assert.Equal(3, results.Count); + Assert.Equal(0, innerModel.GenerateBatchCallCount); + } + + [Fact] + public async Task GenerateAsync_MixedCachedAndUncached_OptimizesInnerModelCalls() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + + // Pre-cache some texts + await model.GenerateAsync(new List { "text1", "text2" }, CancellationToken.None); + innerModel.Reset(); + + var mixedTexts = new List { "text1", "text3", "text2", "text4" }; + + // Act + var results = await model.GenerateAsync(mixedTexts, CancellationToken.None); + + // Assert + Assert.NotNull(results); + Assert.Equal(4, results.Count); + + // Should only call inner model for text3 and text4 + Assert.Equal(1, innerModel.GenerateBatchCallCount); + } + + [Fact] + public async Task GenerateAsync_WithCancellationToken_PassesToInnerModel() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + var cts = new CancellationTokenSource(); + var text = "test text"; + + // Act + var result = await model.GenerateAsync(text, cts.Token); + + // Assert + Assert.NotNull(result); + } + + [Fact] + public async Task GenerateAsync_CacheToggle_WorksCorrectly() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + var text = "test text"; + + // Act & Assert - With cache enabled + var result1 = await model.GenerateAsync(text, CancellationToken.None); + Assert.Equal(1, innerModel.GenerateCallCount); + + // Second call should use cache + innerModel.Reset(); + var result2 = await model.GenerateAsync(text, CancellationToken.None); + Assert.Equal(0, innerModel.GenerateCallCount); + Assert.Equal(result1, result2); + + // Disable cache and call again + model.CacheEnabled = false; + innerModel.Reset(); + var result3 = await model.GenerateAsync(text, CancellationToken.None); + Assert.Equal(1, innerModel.GenerateCallCount); + + // Enable cache again + model.CacheEnabled = true; + innerModel.Reset(); + var result4 = await model.GenerateAsync(text, CancellationToken.None); + Assert.Equal(0, innerModel.GenerateCallCount); // Should use cache + } + + [Fact] + public async Task GenerateAsync_SameTextMultipleTimes_UsesCache() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(DefaultCacheSize); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + var text = "test text"; + + // Act + var result1 = await model.GenerateAsync(text, CancellationToken.None); + var result2 = await model.GenerateAsync(text, CancellationToken.None); + var result3 = await model.GenerateAsync(text, CancellationToken.None); + + // Assert + Assert.Equal(result1, result2); + Assert.Equal(result2, result3); + Assert.Equal(1, innerModel.GenerateCallCount); // Only called once + } + + [Fact] + public async Task GenerateAsync_LargeNumberOfTexts_HandlesCachingCorrectly() + { + // Arrange + var innerModel = new MockTextEmbeddingModel(); + var cache = new TextEmbeddingCache(100); + var model = new TextEmbeddingModelWithCache(innerModel, cache); + var texts = Enumerable.Range(0, 50).Select(i => $"text{i}").ToList(); + + // Act + var results1 = await model.GenerateAsync(texts, CancellationToken.None); + innerModel.Reset(); + var results2 = await model.GenerateAsync(texts, CancellationToken.None); + + // Assert + Assert.Equal(50, results1.Count); + Assert.Equal(50, results2.Count); + Assert.Equal(0, innerModel.GenerateBatchCallCount); // All cached + } +} diff --git a/python/fineTuning/unsloth/batchInfer.py b/python/fineTuning/unsloth/batchInfer.py index 7b7a20899..247b451cb 100644 --- a/python/fineTuning/unsloth/batchInfer.py +++ b/python/fineTuning/unsloth/batchInfer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation and Henry Lucco. +# Licensed under the MIT License. + from unsloth import FastLanguageModel import sys import os diff --git a/python/fineTuning/unsloth/knowledgePrompt.py b/python/fineTuning/unsloth/knowledgePrompt.py index 393ac72f7..88c0de90d 100644 --- a/python/fineTuning/unsloth/knowledgePrompt.py +++ b/python/fineTuning/unsloth/knowledgePrompt.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation and Henry Lucco. +# Licensed under the MIT License. + def get_knowledge_prompt(message: str) -> str: template = "You are a service that translates user messages in a conversation into JSON objects of type \"KnowledgeResponse\" according to the following TypeScript definitions:\n```\n\n\nexport type Quantity = {\n amount: number;\n units: string;\n};\nexport type Value = string | number | boolean | Quantity;\nexport type Facet = {\n name: string;\n // Very concise values.\n value: Value;\n};\n// Specific, tangible people, places, institutions or things only\nexport type ConcreteEntity = {\n // the name of the entity or thing such as \"Bach\", \"Great Gatsby\", \"frog\" or \"piano\"\n name: string;\n // the types of the entity such as \"speaker\", \"person\", \"artist\", \"animal\", \"object\", \"instrument\", \"school\", \"room\", \"museum\", \"food\" etc.\n // An entity can have multiple types; entity types should be single words\n type: string[];\n // A specific, inherent, defining, or non-immediate facet of the entity such as \"blue\", \"old\", \"famous\", \"sister\", \"aunt_of\", \"weight: 4 kg\"\n // trivial actions or state changes are not facets\n // facets are concise \"properties\"\n facets?: Facet[];\n};\nexport type ActionParam = {\n name: string;\n value: Value;\n};\nexport type VerbTense = \"past\" | \"present\" | \"future\";\nexport type Action = {\n // Each verb is typically a word\n verbs: string[];\n verbTense: VerbTense;\n subjectEntityName: string;\n objectEntityName?: string;\n indirectObjectEntityName?: string;\n params?: (string | ActionParam)[];\n // If the action implies this additional facet or property of the subjectEntity, such as hobbies, activities, interests, personality\n subjectEntityFacet?: Facet | undefined;\n};\n// Detailed and comprehensive knowledge response\nexport type KnowledgeResponse = {\n entities: ConcreteEntity[];\n // The 'subjectEntityName' and 'objectEntityName' must correspond to the 'name' of an entity listed in the 'entities' array.\n actions: Action[];\n // Some actions can ALSO be expressed in a reverse way... e.g. (A give to B) --> (B receive from A) and vice versa\n // If so, also return the reverse form of the action, full filled out\n inverseActions: Action[];\n // Detailed, descriptive topics and keyword.\n topics: string[];\n};\n\n```\nThe following are messages in a conversation:\n\"\"\"\n<>\n\"\"\"\nThe following is the user request translated into a JSON object with no spaces and no properties with the value undefined:\n" return template.replace("<>", message) diff --git a/python/fineTuning/unsloth/nltkExtract.py b/python/fineTuning/unsloth/nltkExtract.py index 3fe5fdbed..b086c927b 100644 --- a/python/fineTuning/unsloth/nltkExtract.py +++ b/python/fineTuning/unsloth/nltkExtract.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation and Henry Lucco. +# Licensed under the MIT License. + from rake_nltk import Rake import sys import os diff --git a/python/fineTuning/unsloth/trainEntities.py b/python/fineTuning/unsloth/trainEntities.py index 9cd585dc7..f409d2b86 100644 --- a/python/fineTuning/unsloth/trainEntities.py +++ b/python/fineTuning/unsloth/trainEntities.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation and Henry Lucco. +# Licensed under the MIT License. + from unsloth import FastLanguageModel import torch from knowledgePrompt import get_knowledge_prompt