diff --git a/dotnet/typeagent/examples/examplesLib/KnowProWriter.cs b/dotnet/typeagent/examples/examplesLib/KnowProWriter.cs index 17c339ae5..ea274a5f3 100644 --- a/dotnet/typeagent/examples/examplesLib/KnowProWriter.cs +++ b/dotnet/typeagent/examples/examplesLib/KnowProWriter.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Text.RegularExpressions; using System.Threading.Tasks; using TypeAgent.ExamplesLib.CommandLine; using TypeAgent.KnowPro; @@ -41,6 +42,7 @@ public static async Task WriteMessagesAsync(IConversation conversation) { await foreach (var message in conversation.Messages) { + WriteLine($"[{message.MessageId}]"); WriteMessage(message); WriteLine(); } diff --git a/dotnet/typeagent/examples/knowProConsole/MemoryCommands.cs b/dotnet/typeagent/examples/knowProConsole/MemoryCommands.cs index 9b8988c39..cce48718d 100644 --- a/dotnet/typeagent/examples/knowProConsole/MemoryCommands.cs +++ b/dotnet/typeagent/examples/knowProConsole/MemoryCommands.cs @@ -2,6 +2,10 @@ // Licensed under the MIT License. +using TypeAgent.ExamplesLib.CommandLine; +using TypeAgent.KnowPro; +using TypeAgent.KnowPro.Lang; + namespace KnowProConsole; public class MemoryCommands : ICommandModule @@ -20,7 +24,8 @@ public IList GetCommands() MessagesDef(), AliasesDef(), SearchDef(), - SearchRagDef() + SearchRagDef(), + AnswerRagDef() ]; } @@ -125,7 +130,7 @@ private Task SearchAsync(ParseResult args, CancellationToken cancellationToken) private Command SearchRagDef() { - Command command = new("kpSearchRag") + Command command = new("kpSearchRag", "Text similarity search.") { Options.Arg("query"), Options.Arg("maxMatches", 25), @@ -135,6 +140,7 @@ private Command SearchRagDef() command.SetAction(this.SearchRagAsync); return command; } + private async Task SearchRagAsync(ParseResult args, CancellationToken cancellationToken) { IConversation conversation = EnsureConversation(); @@ -167,4 +173,53 @@ private IConversation EnsureConversation() ? _kpContext.Conversation! : throw new InvalidOperationException("No conversation loaded"); } + + private Command AnswerRagDef() + { + Command command = new("kpAnswerRag", "Answer using classic RAG.") + { + Options.Arg("query"), + Options.Arg("maxMatches", 25), + Options.Arg("minScore", 0.7), + Options.Arg("budget", 16 * 1024) + }; + command.SetAction(this.AnswerRagAsync); + return command; + } + + private async Task AnswerRagAsync(ParseResult args, CancellationToken cancellationToken) + { + + IConversation conversation = EnsureConversation(); + + NamedArgs namedArgs = new NamedArgs(args); + string? query = namedArgs.Get("query"); + if (string.IsNullOrEmpty(query)) + { + return; + } + + AnswerResponse answer = await conversation.AnswerQuestionRagAsync( + query, + namedArgs.Get("minScore"), + namedArgs.Get("budget"), + new AnswerContextOptions() + { + MessagesTopK = namedArgs.Get("maxMatches"), + }, + null, + cancellationToken + ).ConfigureAwait(false); + + KnowProWriter.WriteLine(); + if (answer.Type == AnswerType.Answered) + { + KnowProWriter.WriteLine(ConsoleColor.Green, answer.Answer); + } + else + { + KnowProWriter.WriteLine(ConsoleColor.Yellow, answer.WhyNoAnswer); + } + KnowProWriter.WriteLine(); + } } diff --git a/dotnet/typeagent/src/knowpro/Answer/AnswerContextBuilder.cs b/dotnet/typeagent/src/knowpro/Answer/AnswerContextBuilder.cs index db745fb0b..8e16b9c01 100644 --- a/dotnet/typeagent/src/knowpro/Answer/AnswerContextBuilder.cs +++ b/dotnet/typeagent/src/knowpro/Answer/AnswerContextBuilder.cs @@ -184,7 +184,7 @@ public async ValueTask> GetRelevantMessagesAsync( { return []; } - List ordinals = messageMatches.ToMessageOrdinals(topK); + List ordinals = messageMatches.ToMessageOrdinals(topK, true); IList messages = await _conversation.GetMessageReader().GetAsync( ordinals, cancellationToken diff --git a/dotnet/typeagent/src/knowpro/ConversationAnswer.cs b/dotnet/typeagent/src/knowpro/ConversationAnswer.cs index 12acacf54..f526d9b2b 100644 --- a/dotnet/typeagent/src/knowpro/ConversationAnswer.cs +++ b/dotnet/typeagent/src/knowpro/ConversationAnswer.cs @@ -38,7 +38,6 @@ public static async ValueTask AnswerQuestionAsync( throw new NotImplementedException("Answer chunking"); } - public static async ValueTask AnswerQuestionAsync( this IConversation conversation, string question, @@ -113,4 +112,46 @@ public static async ValueTask AnswerQuestionAsync( ).ConfigureAwait(false); return combinedResponse; } + + /// + /// Performs answer generation using RAG search and RAG context for answer generation + /// + /// The conversation to use as the context for the question being asked. + /// The question being asked. + /// A progresss callback. + /// The cancellation token to abort if necessary. + /// + public static async ValueTask AnswerQuestionRagAsync( + this IConversation conversation, + string question, + double minScore, + int maxCharsInBudget, + AnswerContextOptions? options, + Action? progress = null, + CancellationToken cancellationToken = default + ) + { + ConversationSearchResult searchResults = await conversation.SearchRagAsync( + question, + options.MessagesTopK, + minScore, + maxCharsInBudget, + cancellationToken + ).ConfigureAwait(false); + + if (searchResults is null) + { + return AnswerResponse.NoAnswer(); + } + + IAnswerGenerator generator = conversation.Settings.AnswerGenerator; + AnswerResponse answerResponse = await conversation.AnswerQuestionAsync( + question, + searchResults, + options, + cancellationToken + ).ConfigureAwait(false); + + return answerResponse; + } } diff --git a/dotnet/typeagent/src/knowpro/ConversationSearch.cs b/dotnet/typeagent/src/knowpro/ConversationSearch.cs index bc16db5f8..0a63a673f 100644 --- a/dotnet/typeagent/src/knowpro/ConversationSearch.cs +++ b/dotnet/typeagent/src/knowpro/ConversationSearch.cs @@ -212,7 +212,8 @@ public static async ValueTask> RunQueryAsync( if (maxCharsInBudget is not null) { - var messageOrdinals = messageMatches.ToMessageOrdinals(); + // reverse the message matches so we start with the highest ranked results first + var messageOrdinals = messageMatches.ToMessageOrdinals(null, true); int messageCountInBudget = await conversation.Messages.GetCountInCharBudgetAsync( messageOrdinals, diff --git a/dotnet/typeagent/src/knowpro/IMessage.cs b/dotnet/typeagent/src/knowpro/IMessage.cs index c0ab3546f..23408ed00 100644 --- a/dotnet/typeagent/src/knowpro/IMessage.cs +++ b/dotnet/typeagent/src/knowpro/IMessage.cs @@ -5,6 +5,7 @@ namespace TypeAgent.KnowPro; public interface IMessage : IKnowledgeSource { + int MessageId { get; set; } IList TextChunks { get; set; } IList? Tags { get; set; } diff --git a/dotnet/typeagent/src/knowpro/Message.cs b/dotnet/typeagent/src/knowpro/Message.cs index 3a6a8fc34..6f7e013d6 100644 --- a/dotnet/typeagent/src/knowpro/Message.cs +++ b/dotnet/typeagent/src/knowpro/Message.cs @@ -5,6 +5,7 @@ namespace TypeAgent.KnowPro; public class Message : IMessage { + public int MessageId { get; set; } public IList TextChunks { get; set; } public IList? Tags { get; set; } public string? Timestamp { get; set; } @@ -32,6 +33,9 @@ public Message(string text, TMeta meta) Metadata = meta; } + [JsonPropertyName("messageId")] + public int MessageId { get; set; } + [JsonPropertyName("textChunks")] public IList TextChunks { get; set; } = []; diff --git a/dotnet/typeagent/src/knowpro/MessageExtensions.cs b/dotnet/typeagent/src/knowpro/MessageExtensions.cs index 2355fd770..df913ae61 100644 --- a/dotnet/typeagent/src/knowpro/MessageExtensions.cs +++ b/dotnet/typeagent/src/knowpro/MessageExtensions.cs @@ -24,12 +24,17 @@ public static int GetCharCount(this IMessage message) public static List ToMessageOrdinals( this IList scoredOrdinals, - int? topK = null + int? topK = null, + bool reverse = false ) { return topK is not null && topK.Value < scoredOrdinals.Count - ? scoredOrdinals.Take(topK.Value).Map((s) => s.MessageOrdinal) - : scoredOrdinals.Map((s) => s.MessageOrdinal); + ? reverse + ? scoredOrdinals.Reverse().Take(topK.Value).Map((s) => s.MessageOrdinal) + : scoredOrdinals.Take(topK.Value).Map((s) => s.MessageOrdinal) + : reverse + ? scoredOrdinals.Reverse().Map((s) => s.MessageOrdinal) + : scoredOrdinals.Map((s) => s.MessageOrdinal); } public static IEnumerable AsMessageOrdinals(this IEnumerable scoredOrdinals) diff --git a/dotnet/typeagent/src/knowproStorage/Sqlite/SqliteMessageCollection.cs b/dotnet/typeagent/src/knowproStorage/Sqlite/SqliteMessageCollection.cs index ef8e0b560..492bdbf76 100644 --- a/dotnet/typeagent/src/knowproStorage/Sqlite/SqliteMessageCollection.cs +++ b/dotnet/typeagent/src/knowproStorage/Sqlite/SqliteMessageCollection.cs @@ -186,6 +186,7 @@ TMessage FromMessageRow(MessageRow messageRow) internal class MessageRow { + public int MessageId { get; set; } public string? ChunksJson { get; set; } public string? ChunkUri { get; set; } public int MessageLength { get; set; } @@ -196,6 +197,7 @@ internal class MessageRow public MessageRow Read(SqliteDataReader reader, int iCol = 0) { + MessageId = reader.GetInt32(iCol++); ChunksJson = reader.GetStringOrNull(iCol++); ChunkUri = reader.GetStringOrNull(iCol++); MessageLength = reader.GetInt32(iCol++); @@ -341,6 +343,7 @@ IMessage FromMessageRow(MessageRow messageRow) { IMessage message = (IMessage)Activator.CreateInstance(_messageType); + message.MessageId = messageRow.MessageId; message.TextChunks = StorageSerializer.FromJsonArray(messageRow.ChunksJson); message.Tags = StorageSerializer.FromJsonArray(messageRow.TagsJson); message.Timestamp = messageRow.StartTimestamp; @@ -375,7 +378,7 @@ public static MessageRow GetMessage(SqliteDatabase db, int msgId) KnowProVerify.ThrowIfInvalidMessageOrdinal(msgId); return db.Get(@" -SELECT chunks, chunk_uri, message_length, start_timestamp, tags, metadata, extra +SELECT msg_id, chunks, chunk_uri, message_length, start_timestamp, tags, metadata, extra FROM Messages WHERE msg_id = @msg_id", (cmd) => { @@ -399,7 +402,7 @@ public static IEnumerable GetMessages(SqliteDatabase db, IList FROM Messages WHERE msg_id IN ({SqliteDatabase.MakeInStatement(placeholderIds)}) ", (cmd) => cmd.AddPlaceholderParameters(placeholderIds, batch), - (reader) => new KeyValuePair(reader.GetInt32(0), ReadMessageRow(reader, 1)), + (reader) => new KeyValuePair(reader.GetInt32(0), ReadMessageRow(reader, 0)), messageRows ); } @@ -463,7 +466,7 @@ SELECT start_timestamp public static IAsyncEnumerable GetAllMessagesAsync(SqliteDatabase db, CancellationToken cancellation = default) { return db.EnumerateAsync(@" -SELECT chunks, chunk_uri, message_length, start_timestamp, tags, metadata, extra +SELECT msg_id, chunks, chunk_uri, message_length, start_timestamp, tags, metadata, extra FROM Messages ORDER BY msg_id", ReadMessageRow, cancellation