diff --git a/CLAUDE.md b/CLAUDE.md
index 1fb7d115..c5feb925 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -17,7 +17,17 @@ The repository contains separate packages for R and Python:
```
/
├── pkg-r/ # R package implementation
-│ ├── R/ # R source files
+│ ├── R/ # R source files (R6 classes and utilities)
+│ │ ├── QueryChat.R # Main QueryChat R6 class
+│ │ ├── DataSource.R # Abstract DataSource base class
+│ │ ├── DataFrameSource.R # DataSource for data.frames
+│ │ ├── DBISource.R # DataSource for DBI connections
+│ │ ├── TblSqlSource.R # DataSource for dbplyr tbl_sql
+│ │ ├── QueryChatSystemPrompt.R # System prompt management (internal)
+│ │ ├── querychat_module.R # Shiny module functions (internal)
+│ │ ├── querychat_tools.R # Tool definitions for LLM
+│ │ ├── deprecated.R # Deprecated functional API
+│ │ └── utils-*.R # Utility functions
│ ├── inst/ # Installed files
│ │ ├── examples-shiny/ # Shiny example applications
│ │ ├── htmldep/ # HTML dependencies
@@ -98,19 +108,33 @@ make py-docs-preview
### Core Components
+Both R and Python implementations use an object-oriented architecture:
+
1. **Data Sources**: Abstractions for data frames and database connections that provide schema information and execute SQL queries
- - R: `querychat_data_source()` in `pkg-r/R/data_source.R`
+ - R: R6 class hierarchy in `pkg-r/R/`
+ - `DataSource` - Abstract base class defining the interface (`DataSource.R`)
+ - `DataFrameSource` - For data.frame objects (`DataFrameSource.R`)
+ - `DBISource` - For DBI database connections (`DBISource.R`)
+ - `TblSqlSource` - For dbplyr tbl_sql objects (`TblSqlSource.R`)
- Python: `DataSource` classes in `pkg-py/src/querychat/datasource.py`
2. **LLM Client**: Integration with LLM providers (OpenAI, Anthropic, etc.) through:
- R: ellmer package
- Python: chatlas package
-3. **Query Chat Interface**: UI components and server logic for the chat experience:
- - R: `querychat_sidebar()`, `querychat_ui()`, and `querychat_server()` in `pkg-r/R/querychat.R`
+3. **Query Chat Interface**: Main orchestration class that manages the chat experience:
+ - R: `QueryChat` R6 class in `pkg-r/R/QueryChat.R`
+ - Provides methods: `$new()`, `$app()`, `$sidebar()`, `$ui()`, `$server()`, `$df()`, `$sql()`, etc.
+ - Internal Shiny module functions: `mod_ui()` and `mod_server()` in `pkg-r/R/querychat_module.R`
- Python: `QueryChat` class in `pkg-py/src/querychat/querychat.py`
-4. **Prompt Engineering**: System prompts and tool definitions that guide the LLM:
+4. **System Prompt Management**:
+ - R: `QueryChatSystemPrompt` R6 class in `pkg-r/R/QueryChatSystemPrompt.R`
+ - Handles loading and rendering of prompt templates with Mustache
+ - Manages data descriptions and extra instructions
+   - Python: `QueryChatSystemPrompt` class in `pkg-py/src/querychat/_system_prompt.py`
+
+5. **Prompt Engineering**: System prompts and tool definitions that guide the LLM:
- R: `pkg-r/inst/prompts/`
- Main prompt (`prompt.md`)
- Tool descriptions (`tool-query.md`, `tool-reset-dashboard.md`, `tool-update-dashboard.md`)
@@ -118,6 +142,26 @@ make py-docs-preview
- Main prompt (`prompt.md`)
- Tool descriptions (`tool-query.md`, `tool-reset-dashboard.md`, `tool-update-dashboard.md`)
+### R Package Architecture
+
+The R package uses R6 classes for object-oriented design:
+
+- **QueryChat**: Main user-facing class that orchestrates the entire query chat experience
+ - Takes data sources as input
+ - Provides methods for UI generation (`$sidebar()`, `$ui()`, `$app()`)
+ - Manages server logic and reactive values (`$server()`)
+ - Exposes reactive accessors (`$df()`, `$sql()`, `$title()`)
+
+- **DataSource hierarchy**: Abstract interface for different data backends
+ - All implementations provide: `get_schema()`, `execute_query()`, `test_query()`, `get_data()`
+ - Allows QueryChat to work with data.frames, DBI connections, and dbplyr objects uniformly
+
+- **QueryChatSystemPrompt**: Internal class for prompt template management
+ - Loads templates from files or strings
+ - Renders prompts with tool configurations using Mustache
+
+The package has deprecated the old functional API (`querychat_init()`, `querychat_server()`, etc.) in favor of the R6 class approach. See `pkg-r/R/deprecated.R` for migration guidance.
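+
+A minimal sketch of the R6 workflow (method names are as documented above; any constructor arguments beyond the data source are assumptions, not the definitive signatures):
+
+```r
+# Wrap a data.frame in a DataSource, then hand it to QueryChat
+source <- DataFrameSource$new(mtcars, "mtcars")
+qc <- QueryChat$new(source)
+
+qc$app()  # launch a standalone Shiny app with the chat sidebar
+
+# In a larger app: place qc$sidebar() or qc$ui() in the UI, call
+# qc$server() in the server, and read qc$df() / qc$sql() as reactives.
+```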
+
### Data Flow
1. User enters a natural language query in the UI
diff --git a/pkg-py/src/querychat/_system_prompt.py b/pkg-py/src/querychat/_system_prompt.py
index 81936fb9..c4ddc672 100644
--- a/pkg-py/src/querychat/_system_prompt.py
+++ b/pkg-py/src/querychat/_system_prompt.py
@@ -75,6 +75,7 @@ def render(self, tools: tuple[TOOL_GROUPS, ...] | None) -> str:
"extra_instructions": self.extra_instructions,
"has_tool_update": "update" in tools if tools else False,
"has_tool_query": "query" in tools if tools else False,
+ "include_query_guidelines": len(tools or ()) > 0,
}
return chevron.render(self.template, context)
diff --git a/pkg-py/src/querychat/prompts/prompt.md b/pkg-py/src/querychat/prompts/prompt.md
index 6a92101c..7c8ea5a1 100644
--- a/pkg-py/src/querychat/prompts/prompt.md
+++ b/pkg-py/src/querychat/prompts/prompt.md
@@ -16,6 +16,44 @@ Here is additional information about the data:
For security reasons, you may only query this specific table.
+{{#include_query_guidelines}}
+## SQL Query Guidelines
+
+When writing SQL queries to interact with the database, please adhere to the following guidelines to ensure compatibility and correctness.
+
+### Structural Rules
+
+**No trailing semicolons**
+Never end your query with a semicolon (`;`). The parent query needs to continue after your subquery closes.
+
+**Single statement only**
+Return exactly one `SELECT` statement. Do not include multiple statements separated by semicolons.
+
+**No procedural or meta statements**
+Do not include:
+- `EXPLAIN` / `EXPLAIN ANALYZE`
+- `SET` statements
+- Variable declarations
+- Transaction controls (`BEGIN`, `COMMIT`, `ROLLBACK`)
+- DDL statements (`CREATE`, `ALTER`, `DROP`)
+- `INTO` clauses (e.g., `SELECT INTO`)
+- Locking hints (`FOR UPDATE`, `FOR SHARE`)
+
+### Column Naming Rules
+
+**Alias all computed/derived columns**
+Every expression that isn't a simple column reference must have an explicit alias.
+
+**Ensure unique column names**
+The result set must not have duplicate column names, even when selecting from multiple tables.
+
+**Avoid `SELECT *` with JOINs**
+Explicitly list columns to prevent duplicate column names and ensure a predictable output schema.
+
+**Avoid reserved words as unquoted aliases**
+If using reserved words as column aliases, quote them appropriately for your dialect.
+
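+### Example
+
+A query that follows the rules above (the table and column names are illustrative only):
+
+```sql
+SELECT
+  region,
+  AVG(amount) AS avg_amount,  -- computed column gets an explicit alias
+  COUNT(*) AS n_rows
+FROM sales
+GROUP BY region
+ORDER BY avg_amount DESC      -- single SELECT, no trailing semicolon
+```
+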
+{{/include_query_guidelines}}
{{#is_duck_db}}
### DuckDB SQL Tips
@@ -130,7 +168,7 @@ You might want to explore the advanced features
- The user has asked a very specific question requiring only a direct answer
- The conversation is clearly wrapping up
-#### Guidelines
+#### Suggestion Guidelines
- Suggestions can appear **anywhere** in your response—not just at the end
- Use list format at the end for 2-4 follow-up options (most common pattern)
@@ -141,7 +179,6 @@ You might want to explore the advanced features
- Never use generic phrases like "If you'd like to..." or "Would you like to explore..." — instead, provide concrete suggestions
- Never refer to suggestions as "prompts" – call them "suggestions" or "ideas" or similar
-
## Important Guidelines
- **Ask for clarification** if any request is unclear or ambiguous
diff --git a/pkg-py/tests/test_system_prompt.py b/pkg-py/tests/test_system_prompt.py
index 8df49df9..32ee5545 100644
--- a/pkg-py/tests/test_system_prompt.py
+++ b/pkg-py/tests/test_system_prompt.py
@@ -30,7 +30,9 @@ def sample_prompt_template():
{{#data_description}}Data: {{data_description}}{{/data_description}}
{{#extra_instructions}}Instructions: {{extra_instructions}}{{/extra_instructions}}
{{#has_tool_update}}UPDATE TOOL ENABLED{{/has_tool_update}}
-{{#has_tool_query}}QUERY TOOL ENABLED{{/has_tool_query}}"""
+{{#has_tool_query}}QUERY TOOL ENABLED{{/has_tool_query}}
+{{#include_query_guidelines}}QUERY GUIDELINES{{/include_query_guidelines}}
+"""
class TestQueryChatSystemPromptInit:
@@ -157,6 +159,7 @@ def test_render_with_both_tools(self, sample_data_source, sample_prompt_template
assert "UPDATE TOOL ENABLED" in rendered
assert "QUERY TOOL ENABLED" in rendered
+ assert "QUERY GUIDELINES" in rendered
assert "Database Type:" in rendered
assert "Schema:" in rendered
@@ -171,6 +174,7 @@ def test_render_with_query_only(self, sample_data_source, sample_prompt_template
assert "UPDATE TOOL ENABLED" not in rendered
assert "QUERY TOOL ENABLED" in rendered
+ assert "QUERY GUIDELINES" in rendered
def test_render_with_update_only(self, sample_data_source, sample_prompt_template):
"""Test rendering with only update tool enabled."""
@@ -183,6 +187,7 @@ def test_render_with_update_only(self, sample_data_source, sample_prompt_templat
assert "UPDATE TOOL ENABLED" in rendered
assert "QUERY TOOL ENABLED" not in rendered
+ assert "QUERY GUIDELINES" in rendered
def test_render_with_no_tools(self, sample_data_source, sample_prompt_template):
"""Test rendering with no tools enabled."""
@@ -195,6 +200,7 @@ def test_render_with_no_tools(self, sample_data_source, sample_prompt_template):
assert "UPDATE TOOL ENABLED" not in rendered
assert "QUERY TOOL ENABLED" not in rendered
+ assert "QUERY GUIDELINES" not in rendered
def test_render_includes_data_description(
self, sample_data_source, sample_prompt_template
diff --git a/pkg-r/DESCRIPTION b/pkg-r/DESCRIPTION
index d9637ea5..fd1ff3e1 100644
--- a/pkg-r/DESCRIPTION
+++ b/pkg-r/DESCRIPTION
@@ -37,6 +37,8 @@ Imports:
whisker
Suggests:
bsicons,
+ dbplyr,
+ dplyr,
DT,
duckdb,
knitr,
diff --git a/pkg-r/NAMESPACE b/pkg-r/NAMESPACE
index 4765fced..080bf698 100644
--- a/pkg-r/NAMESPACE
+++ b/pkg-r/NAMESPACE
@@ -4,6 +4,7 @@ export(DBISource)
export(DataFrameSource)
export(DataSource)
export(QueryChat)
+export(TblSqlSource)
export(querychat)
export(querychat_app)
export(querychat_data_source)
diff --git a/pkg-r/R/DBISource.R b/pkg-r/R/DBISource.R
new file mode 100644
index 00000000..9e360197
--- /dev/null
+++ b/pkg-r/R/DBISource.R
@@ -0,0 +1,388 @@
+#' DBI Source
+#'
+#' @description
+#' A DataSource implementation for DBI database connections (SQLite, PostgreSQL,
+#' MySQL, etc.). This class wraps a DBI connection and provides SQL query
+#' execution against a single table in the database.
+#'
+#' @export
+#' @examples
+#' \dontrun{
+#' # Connect to a database
+#' con <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
+#' DBI::dbWriteTable(con, "mtcars", mtcars)
+#'
+#' # Create a DBI source
+#' db_source <- DBISource$new(con, "mtcars")
+#'
+#' # Get database type
+#' db_source$get_db_type() # Returns "SQLite"
+#'
+#' # Execute a query
+#' result <- db_source$execute_query("SELECT * FROM mtcars WHERE mpg > 25")
+#'
+#' # Note: cleanup() will disconnect the connection
+#' # If you want to keep the connection open, don't call cleanup()
+#' }
+DBISource <- R6::R6Class(
+ "DBISource",
+ inherit = DataSource,
+ private = list(
+ conn = NULL
+ ),
+ public = list(
+ #' @description
+ #' Create a new DBISource
+ #'
+ #' @param conn A DBI connection object
+ #' @param table_name Name of the table in the database. Can be a character
+ #' string or a [DBI::Id()] object for tables in catalogs/schemas
+ #' @return A new DBISource object
+ #' @examples
+ #' \dontrun{
+ #' con <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
+ #' DBI::dbWriteTable(con, "iris", iris)
+ #' source <- DBISource$new(con, "iris")
+ #' }
+ initialize = function(conn, table_name) {
+ if (!inherits(conn, "DBIConnection")) {
+ cli::cli_abort(
+ "{.arg conn} must be a {.cls DBIConnection}, not {.obj_type_friendly {conn}}"
+ )
+ }
+
+ # Validate table_name type
+    is_valid_name <- inherits(table_name, "Id") ||
+      (is.character(table_name) && length(table_name) == 1)
+    if (!is_valid_name) {
+      cli::cli_abort(
+        "{.arg table_name} must be a single character string or a {.fn DBI::Id} object"
+      )
+    }
+
+ # Check if table exists
+ if (!DBI::dbExistsTable(conn, table_name)) {
+ cli::cli_abort(c(
+ "Table {.val {DBI::dbQuoteIdentifier(conn, table_name)}} not found in database",
+ "i" = "If you're using a table in a catalog or schema, pass a {.fn DBI::Id} object to {.arg table_name}"
+ ))
+ }
+
+ private$conn <- conn
+ self$table_name <- table_name
+
+ # Store original column names for validation
+ private$colnames <- colnames(DBI::dbGetQuery(
+ conn,
+ sprintf(
+ "SELECT * FROM %s LIMIT 0",
+ DBI::dbQuoteIdentifier(conn, table_name)
+ )
+ ))
+ },
+
+ #' @description Get the database type
+ #' @return A string identifying the database type
+ get_db_type = function() {
+ # Special handling for known database types
+ if (inherits(private$conn, "duckdb_connection")) {
+ return("DuckDB")
+ }
+ if (inherits(private$conn, "SQLiteConnection")) {
+ return("SQLite")
+ }
+
+ # Default to 'POSIX' if dbms name not found
+ conn_info <- DBI::dbGetInfo(private$conn)
+ dbms_name <- getElement(conn_info, "dbms.name") %||% "POSIX"
+
+    # Remove " SQL" from the name, if present ("SQL" already appears in the prompt)
+ gsub(" SQL", "", dbms_name)
+ },
+
+ #' @description
+ #' Get schema information for the database table
+ #'
+ #' @param categorical_threshold Maximum number of unique values for a text
+ #' column to be considered categorical (default: 20)
+ #' @return A string describing the schema
+ get_schema = function(categorical_threshold = 20) {
+ check_number_whole(categorical_threshold, min = 1)
+ get_schema_impl(private$conn, self$table_name, categorical_threshold)
+ },
+
+ #' @description
+ #' Execute a SQL query
+ #'
+ #' @param query SQL query string. If NULL or empty, returns all data
+ #' @return A data frame with query results
+ execute_query = function(query) {
+ check_string(query, allow_null = TRUE, allow_empty = TRUE)
+ if (is.null(query) || !nzchar(query)) {
+ query <- paste0(
+ "SELECT * FROM ",
+ DBI::dbQuoteIdentifier(private$conn, self$table_name)
+ )
+ }
+
+ check_query(query)
+ DBI::dbGetQuery(private$conn, query)
+ },
+
+ #' @description
+ #' Test a SQL query by fetching only one row
+ #'
+ #' @param query SQL query string
+ #' @param require_all_columns If TRUE, validates that the result includes
+ #' all original table columns (default: FALSE)
+ #' @return A data frame with one row of results
+ test_query = function(query, require_all_columns = FALSE) {
+ check_string(query)
+ check_bool(require_all_columns)
+ check_query(query)
+
+ rs <- DBI::dbSendQuery(private$conn, query)
+ df <- DBI::dbFetch(rs, n = 1)
+ DBI::dbClearResult(rs)
+
+ if (require_all_columns) {
+ result_columns <- names(df)
+ missing_columns <- setdiff(private$colnames, result_columns)
+
+ if (length(missing_columns) > 0) {
+ missing_list <- paste0("'", missing_columns, "'", collapse = ", ")
+ cli::cli_abort(
+ c(
+ "Query result missing required columns: {missing_list}",
+ "i" = "The query must return all original table columns (in any order)."
+ ),
+ class = "querychat_missing_columns_error"
+ )
+ }
+ }
+
+ df
+ },
+
+ #' @description
+ #' Get all data from the table
+ #'
+ #' @return A data frame containing all data
+ get_data = function() {
+ self$execute_query(NULL)
+ },
+
+ #' @description
+ #' Disconnect from the database
+ #'
+ #' @return NULL (invisibly)
+ cleanup = function() {
+ if (!is.null(private$conn) && DBI::dbIsValid(private$conn)) {
+ DBI::dbDisconnect(private$conn)
+ }
+ invisible(NULL)
+ }
+ )
+)
+
+
+get_schema_impl <- function(
+ conn,
+ table_name,
+ categorical_threshold = 20,
+ columns = NULL,
+ prep_query = identity
+) {
+ check_function(prep_query)
+
+ # Get column information
+ columns <- columns %||% DBI::dbListFields(conn, table_name)
+
+ schema_lines <- c(
+ paste("Table:", DBI::dbQuoteIdentifier(conn, table_name)),
+ "Columns:"
+ )
+
+ # Build single query to get column statistics
+ select_parts <- character(0)
+ numeric_columns <- character(0)
+ text_columns <- character(0)
+
+ # Get sample of data to determine types
+ sample_query <- paste0(
+ "SELECT * FROM ",
+ DBI::dbQuoteIdentifier(conn, table_name),
+ " LIMIT 1"
+ )
+ sample_data <- DBI::dbGetQuery(conn, prep_query(sample_query))
+
+ for (col in columns) {
+ col_class <- class(sample_data[[col]])[1]
+
+ if (
+ col_class %in%
+ c("integer", "numeric", "double", "Date", "POSIXct", "POSIXt")
+ ) {
+ numeric_columns <- c(numeric_columns, col)
+ select_parts <- c(
+ select_parts,
+ paste0(
+ "MIN(",
+ DBI::dbQuoteIdentifier(conn, col),
+ ") as ",
+ DBI::dbQuoteIdentifier(conn, paste0(col, '__min'))
+ ),
+ paste0(
+ "MAX(",
+ DBI::dbQuoteIdentifier(conn, col),
+ ") as ",
+ DBI::dbQuoteIdentifier(conn, paste0(col, '__max'))
+ )
+ )
+ } else if (col_class %in% c("character", "factor")) {
+ text_columns <- c(text_columns, col)
+ select_parts <- c(
+ select_parts,
+ paste0(
+ "COUNT(DISTINCT ",
+ DBI::dbQuoteIdentifier(conn, col),
+ ") as ",
+ DBI::dbQuoteIdentifier(conn, paste0(col, '__distinct_count'))
+ )
+ )
+ }
+ }
+
+ # Execute statistics query
+ column_stats <- list()
+ if (length(select_parts) > 0) {
+ tryCatch(
+ {
+ stats_query <- paste0(
+ "SELECT ",
+ paste0(select_parts, collapse = ", "),
+ " FROM ",
+ DBI::dbQuoteIdentifier(conn, table_name)
+ )
+ result <- DBI::dbGetQuery(conn, prep_query(stats_query))
+ if (nrow(result) > 0) {
+        # drop = FALSE keeps a 1-row data.frame so as.list() preserves
+        # column names even when only one statistic was selected
+        column_stats <- as.list(result[1, , drop = FALSE])
+ }
+ },
+ error = function(e) {
+ # Fall back to no statistics if query fails
+ }
+ )
+ }
+
+ # Get categorical values for text columns below threshold
+ categorical_values <- list()
+ text_cols_to_query <- character(0)
+
+ for (col_name in text_columns) {
+ distinct_count_key <- paste0(col_name, "__distinct_count")
+ if (
+ distinct_count_key %in%
+ names(column_stats) &&
+ !is.na(column_stats[[distinct_count_key]]) &&
+ column_stats[[distinct_count_key]] <= categorical_threshold
+ ) {
+ text_cols_to_query <- c(text_cols_to_query, col_name)
+ }
+ }
+
+ # Remove duplicates
+ text_cols_to_query <- unique(text_cols_to_query)
+
+ # Get categorical values
+ if (length(text_cols_to_query) > 0) {
+ for (col_name in text_cols_to_query) {
+ tryCatch(
+ {
+ cat_query <- paste0(
+ "SELECT DISTINCT ",
+ DBI::dbQuoteIdentifier(conn, col_name),
+ " FROM ",
+ DBI::dbQuoteIdentifier(conn, table_name),
+ " WHERE ",
+ DBI::dbQuoteIdentifier(conn, col_name),
+ " IS NOT NULL ORDER BY ",
+ DBI::dbQuoteIdentifier(conn, col_name)
+ )
+ result <- DBI::dbGetQuery(conn, prep_query(cat_query))
+ if (nrow(result) > 0) {
+ categorical_values[[col_name]] <- result[[1]]
+ }
+ },
+ error = function(e) {
+ # Skip categorical values if query fails
+ }
+ )
+ }
+ }
+
+ # Build schema description
+ for (col in columns) {
+ col_class <- class(sample_data[[col]])[1]
+ sql_type <- r_class_to_sql_type(col_class)
+
+ column_info <- paste0("- ", col, " (", sql_type, ")")
+
+ # Add range info for numeric columns
+ if (col %in% numeric_columns) {
+ min_key <- paste0(col, "__min")
+ max_key <- paste0(col, "__max")
+ if (
+ min_key %in%
+ names(column_stats) &&
+ max_key %in% names(column_stats) &&
+ !is.na(column_stats[[min_key]]) &&
+ !is.na(column_stats[[max_key]])
+ ) {
+ range_info <- paste0(
+ " Range: ",
+ column_stats[[min_key]],
+ " to ",
+ column_stats[[max_key]]
+ )
+ column_info <- paste(column_info, range_info, sep = "\n")
+ }
+ }
+
+ # Add categorical values for text columns
+ if (col %in% names(categorical_values)) {
+ values <- categorical_values[[col]]
+ if (length(values) > 0) {
+ values_str <- paste0("'", values, "'", collapse = ", ")
+ cat_info <- paste0(" Categorical values: ", values_str)
+ column_info <- paste(column_info, cat_info, sep = "\n")
+ }
+ }
+
+ schema_lines <- c(schema_lines, column_info)
+ }
+
+ paste(schema_lines, collapse = "\n")
+}
+
+
+# nocov start
+# Map R classes to SQL types
+r_class_to_sql_type <- function(r_class) {
+ switch(
+ r_class,
+ "integer" = "INTEGER",
+ "numeric" = "FLOAT",
+ "double" = "FLOAT",
+ "logical" = "BOOLEAN",
+ "Date" = "DATE",
+ "POSIXct" = "TIMESTAMP",
+ "POSIXt" = "TIMESTAMP",
+ "character" = "TEXT",
+ "factor" = "TEXT",
+ "TEXT" # default
+ )
+}
+# nocov end
diff --git a/pkg-r/R/DataFrameSource.R b/pkg-r/R/DataFrameSource.R
new file mode 100644
index 00000000..2840abe0
--- /dev/null
+++ b/pkg-r/R/DataFrameSource.R
@@ -0,0 +1,128 @@
+#' Data Frame Source
+#'
+#' @description
+#' A DataSource implementation that wraps a data frame using DuckDB or SQLite
+#' for SQL query execution.
+#'
+#' @details
+#' This class creates an in-memory database connection and registers the
+#' provided data frame as a table. All SQL queries are executed against this
+#' database table. See [DBISource] for the full description of available
+#' methods.
+#'
+#' By default, DataFrameSource uses the first available engine from duckdb
+#' (checked first) or RSQLite. You can explicitly set the `engine` parameter to
+#' choose between "duckdb" or "sqlite", or set the global option
+#' `querychat.DataFrameSource.engine` to choose the default engine for all
+#' DataFrameSource instances. At least one of these packages must be installed.
+#'
+#' @export
+#' @examples
+#' \dontrun{
+#' # Create a data frame source (uses first available: duckdb or sqlite)
+#' df_source <- DataFrameSource$new(mtcars, "mtcars")
+#'
+#' # Get database type
+#' df_source$get_db_type() # Returns "DuckDB" or "SQLite"
+#'
+#' # Execute a query
+#' result <- df_source$execute_query("SELECT * FROM mtcars WHERE mpg > 25")
+#'
+#' # Explicitly choose an engine
+#' df_sqlite <- DataFrameSource$new(mtcars, "mtcars", engine = "sqlite")
+#'
+#' # Clean up when done
+#' df_source$cleanup()
+#' df_sqlite$cleanup()
+#' }
+DataFrameSource <- R6::R6Class(
+ "DataFrameSource",
+ inherit = DBISource,
+ private = list(
+ conn = NULL
+ ),
+ public = list(
+ #' @description
+ #' Create a new DataFrameSource
+ #'
+ #' @param df A data frame.
+ #' @param table_name Name to use for the table in SQL queries. Must be a
+ #' valid table name (start with letter, contain only letters, numbers,
+ #' and underscores)
+ #' @param engine Database engine to use: "duckdb" or "sqlite". Set the
+ #' global option `querychat.DataFrameSource.engine` to specify the default
+ #' engine for all instances. If NULL (default), uses the first available
+ #' engine from duckdb or RSQLite (in that order).
+ #' @return A new DataFrameSource object
+ #' @examples
+ #' \dontrun{
+ #' source <- DataFrameSource$new(iris, "iris")
+ #' }
+ initialize = function(
+ df,
+ table_name,
+ engine = getOption("querychat.DataFrameSource.engine", NULL)
+ ) {
+ check_data_frame(df)
+ check_sql_table_name(table_name)
+
+ engine <- engine %||% get_default_dataframe_engine()
+ engine <- tolower(engine)
+ arg_match(engine, c("duckdb", "sqlite"))
+
+ self$table_name <- table_name
+ private$colnames <- colnames(df)
+
+ # Create in-memory connection and register the data frame
+ if (engine == "duckdb") {
+ check_installed("duckdb")
+
+ private$conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:")
+
+ duckdb::duckdb_register(
+ private$conn,
+ table_name,
+ df,
+ experimental = FALSE
+ )
+
+ DBI::dbExecute(
+ private$conn,
+ r"(
+-- extensions: lock down supply chain + auto behaviors
+SET allow_community_extensions = false;
+SET allow_unsigned_extensions = false;
+SET autoinstall_known_extensions = false;
+SET autoload_known_extensions = false;
+
+-- external I/O: block file/database/network access from SQL
+SET enable_external_access = false;
+SET disabled_filesystems = 'LocalFileSystem';
+
+-- freeze configuration so user SQL can't relax anything
+SET lock_configuration = true;
+ )"
+ )
+ } else if (engine == "sqlite") {
+ check_installed("RSQLite")
+ private$conn <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
+ DBI::dbWriteTable(private$conn, table_name, df)
+ }
+ }
+ )
+)
+
+get_default_dataframe_engine <- function() {
+ if (is_installed("duckdb")) {
+ return("duckdb")
+ }
+ if (is_installed("RSQLite")) {
+ return("sqlite")
+ }
+ cli::cli_abort(c(
+ "No compatible database engine installed for DataFrameSource",
+ "i" = "Install either {.pkg duckdb} or {.pkg RSQLite}:",
+ " " = "{.run install.packages(\"duckdb\")}",
+ " " = "{.run install.packages(\"RSQLite\")}"
+ ))
+}
diff --git a/pkg-r/R/DataSource.R b/pkg-r/R/DataSource.R
index 33253560..76d5b5d4 100644
--- a/pkg-r/R/DataSource.R
+++ b/pkg-r/R/DataSource.R
@@ -90,315 +90,6 @@ DataSource <- R6::R6Class(
)
)
-
-#' Data Frame Source
-#'
-#' @description
-#' A DataSource implementation that wraps a data frame using DuckDB or SQLite
-#' for SQL query execution.
-#'
-#' @details
-#' This class creates an in-memory database connection and registers the
-#' provided data frame as a table. All SQL queries are executed against this
-#' database table. See [DBISource] for the full description of available
-#' methods.
-#'
-#' By default, DataFrameSource uses the first available engine from duckdb
-#' (checked first) or RSQLite. You can explicitly set the `engine` parameter to
-#' choose between "duckdb" or "sqlite", or set the global option
-#' `querychat.DataFrameSource.engine` to choose the default engine for all
-#' DataFrameSource instances. At least one of these packages must be installed.
-#'
-#' @export
-#' @examples
-#' \dontrun{
-#' # Create a data frame source (uses first available: duckdb or sqlite)
-#' df_source <- DataFrameSource$new(mtcars, "mtcars")
-#'
-#' # Get database type
-#' df_source$get_db_type() # Returns "DuckDB" or "SQLite"
-#'
-#' # Execute a query
-#' result <- df_source$execute_query("SELECT * FROM mtcars WHERE mpg > 25")
-#'
-#' # Explicitly choose an engine
-#' df_sqlite <- DataFrameSource$new(mtcars, "mtcars", engine = "sqlite")
-#'
-#' # Clean up when done
-#' df_source$cleanup()
-#' df_sqlite$cleanup()
-#' }
-DataFrameSource <- R6::R6Class(
- "DataFrameSource",
- inherit = DBISource,
- private = list(
- conn = NULL
- ),
- public = list(
- #' @description
- #' Create a new DataFrameSource
- #'
- #' @param df A data frame.
- #' @param table_name Name to use for the table in SQL queries. Must be a
- #' valid table name (start with letter, contain only letters, numbers,
- #' and underscores)
- #' @param engine Database engine to use: "duckdb" or "sqlite". Set the
- #' global option `querychat.DataFrameSource.engine` to specify the default
- #' engine for all instances. If NULL (default), uses the first available
- #' engine from duckdb or RSQLite (in that order).
- #' @return A new DataFrameSource object
- #' @examples
- #' \dontrun{
- #' source <- DataFrameSource$new(iris, "iris")
- #' }
- initialize = function(
- df,
- table_name,
- engine = getOption("querychat.DataFrameSource.engine", NULL)
- ) {
- check_data_frame(df)
- check_sql_table_name(table_name)
-
- engine <- engine %||% get_default_dataframe_engine()
- engine <- tolower(engine)
- arg_match(engine, c("duckdb", "sqlite"))
-
- self$table_name <- table_name
- private$colnames <- colnames(df)
-
- # Create in-memory connection and register the data frame
- if (engine == "duckdb") {
- check_installed("duckdb")
-
- private$conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:")
-
- duckdb::duckdb_register(
- private$conn,
- table_name,
- df,
- experimental = FALSE
- )
-
- DBI::dbExecute(
- private$conn,
- r"(
--- extensions: lock down supply chain + auto behaviors
-SET allow_community_extensions = false;
-SET allow_unsigned_extensions = false;
-SET autoinstall_known_extensions = false;
-SET autoload_known_extensions = false;
-
--- external I/O: block file/database/network access from SQL
-SET enable_external_access = false;
-SET disabled_filesystems = 'LocalFileSystem';
-
--- freeze configuration so user SQL can't relax anything
-SET lock_configuration = true;
- )"
- )
- } else if (engine == "sqlite") {
- check_installed("RSQLite")
- private$conn <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
- DBI::dbWriteTable(private$conn, table_name, df)
- }
- }
- )
-)
-
-
-#' DBI Source
-#'
-#' @description
-#' A DataSource implementation for DBI database connections (SQLite, PostgreSQL,
-#' MySQL, etc.).
-#'
-#' @details
-#' This class wraps a DBI connection and provides SQL query execution against
-#' a specified table in the database.
-#'
-#' @export
-#' @examples
-#' \dontrun{
-#' # Connect to a database
-#' conn <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
-#' DBI::dbWriteTable(conn, "mtcars", mtcars)
-#'
-#' # Create a DBI source
-#' db_source <- DBISource$new(conn, "mtcars")
-#'
-#' # Get database type
-#' db_source$get_db_type() # Returns "SQLite"
-#'
-#' # Execute a query
-#' result <- db_source$execute_query("SELECT * FROM mtcars WHERE mpg > 25")
-#'
-#' # Note: cleanup() will disconnect the connection
-#' # If you want to keep the connection open, don't call cleanup()
-#' }
-DBISource <- R6::R6Class(
- "DBISource",
- inherit = DataSource,
- private = list(
- conn = NULL
- ),
- public = list(
- #' @description
- #' Create a new DBISource
- #'
- #' @param conn A DBI connection object
- #' @param table_name Name of the table in the database. Can be a character
- #' string or a [DBI::Id()] object for tables in catalogs/schemas
- #' @return A new DBISource object
- #' @examples
- #' \dontrun{
- #' conn <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
- #' DBI::dbWriteTable(conn, "iris", iris)
- #' source <- DBISource$new(conn, "iris")
- #' }
- initialize = function(conn, table_name) {
- if (!inherits(conn, "DBIConnection")) {
- cli::cli_abort(
- "{.arg conn} must be a {.cls DBIConnection}, not {.obj_type_friendly {conn}}"
- )
- }
-
- # Validate table_name type
- if (inherits(table_name, "Id")) {
- # DBI::Id object - keep as is
- } else if (is.character(table_name) && length(table_name) == 1) {
- # Character string - keep as is
- } else {
- cli::cli_abort(
- "{.arg table_name} must be a single character string or a {.fn DBI::Id} object"
- )
- }
-
- # Check if table exists
- if (!DBI::dbExistsTable(conn, table_name)) {
- cli::cli_abort(c(
- "Table {.val {DBI::dbQuoteIdentifier(conn, table_name)}} not found in database",
- "i" = "If you're using a table in a catalog or schema, pass a {.fn DBI::Id} object to {.arg table_name}"
- ))
- }
-
- private$conn <- conn
- self$table_name <- table_name
-
- # Store original column names for validation
- private$colnames <- colnames(DBI::dbGetQuery(
- conn,
- sprintf(
- "SELECT * FROM %s LIMIT 0",
- DBI::dbQuoteIdentifier(conn, table_name)
- )
- ))
- },
-
- #' @description Get the database type
- #' @return A string identifying the database type
- get_db_type = function() {
- # Special handling for known database types
- if (inherits(private$conn, "duckdb_connection")) {
- return("DuckDB")
- }
- if (inherits(private$conn, "SQLiteConnection")) {
- return("SQLite")
- }
-
- # Default to 'POSIX' if dbms name not found
- conn_info <- DBI::dbGetInfo(private$conn)
- dbms_name <- getElement(conn_info, "dbms.name") %||% "POSIX"
-
- # Remove ' SQL', if exists (SQL is already in the prompt)
- gsub(" SQL", "", dbms_name)
- },
-
- #' @description
- #' Get schema information for the database table
- #'
- #' @param categorical_threshold Maximum number of unique values for a text
- #' column to be considered categorical (default: 20)
- #' @return A string describing the schema
- get_schema = function(categorical_threshold = 20) {
- check_number_whole(categorical_threshold, min = 1)
- get_schema_impl(private$conn, self$table_name, categorical_threshold)
- },
-
- #' @description
- #' Execute a SQL query
- #'
- #' @param query SQL query string. If NULL or empty, returns all data
- #' @return A data frame with query results
- execute_query = function(query) {
- check_string(query, allow_null = TRUE, allow_empty = TRUE)
- if (is.null(query) || !nzchar(query)) {
- query <- paste0(
- "SELECT * FROM ",
- DBI::dbQuoteIdentifier(private$conn, self$table_name)
- )
- }
-
- check_query(query)
- DBI::dbGetQuery(private$conn, query)
- },
-
- #' @description
- #' Test a SQL query by fetching only one row
- #'
- #' @param query SQL query string
- #' @param require_all_columns If TRUE, validates that the result includes
- #' all original table columns (default: FALSE)
- #' @return A data frame with one row of results
- test_query = function(query, require_all_columns = FALSE) {
- check_string(query)
- check_bool(require_all_columns)
- check_query(query)
-
- rs <- DBI::dbSendQuery(private$conn, query)
- df <- DBI::dbFetch(rs, n = 1)
- DBI::dbClearResult(rs)
-
- if (require_all_columns) {
- result_columns <- names(df)
- missing_columns <- setdiff(private$colnames, result_columns)
-
- if (length(missing_columns) > 0) {
- missing_list <- paste0("'", missing_columns, "'", collapse = ", ")
- cli::cli_abort(
- c(
- "Query result missing required columns: {missing_list}",
- "i" = "The query must return all original table columns (in any order)."
- ),
- class = "querychat_missing_columns_error"
- )
- }
- }
-
- df
- },
-
- #' @description
- #' Get all data from the table
- #'
- #' @return A data frame containing all data
- get_data = function() {
- self$execute_query(NULL)
- },
-
- #' @description
- #' Disconnect from the database
- #'
- #' @return NULL (invisibly)
- cleanup = function() {
- if (!is.null(private$conn) && DBI::dbIsValid(private$conn)) {
- DBI::dbDisconnect(private$conn)
- }
- invisible(NULL)
- }
- )
-)
-
-
# Helper Functions -------------------------------------------------------------
#' Check if object is a DataSource
@@ -409,211 +100,3 @@ DBISource <- R6::R6Class(
is_data_source <- function(x) {
inherits(x, "DataSource")
}
-
-
-get_default_dataframe_engine <- function() {
- if (is_installed("duckdb")) {
- return("duckdb")
- }
- if (is_installed("RSQLite")) {
- return("sqlite")
- }
- cli::cli_abort(c(
- "No compatible database engine installed for DataFrameSource",
- "i" = "Install either {.pkg duckdb} or {.pkg RSQLite}:",
- " " = "{.run install.packages(\"duckdb\")}",
- " " = "{.run install.packages(\"RSQLite\")}"
- ))
-}
-
-
-get_schema_impl <- function(conn, table_name, categorical_threshold = 20) {
- # Get column information
- columns <- DBI::dbListFields(conn, table_name)
-
- schema_lines <- c(
- paste("Table:", DBI::dbQuoteIdentifier(conn, table_name)),
- "Columns:"
- )
-
- # Build single query to get column statistics
- select_parts <- character(0)
- numeric_columns <- character(0)
- text_columns <- character(0)
-
- # Get sample of data to determine types
- sample_query <- paste0(
- "SELECT * FROM ",
- DBI::dbQuoteIdentifier(conn, table_name),
- " LIMIT 1"
- )
- sample_data <- DBI::dbGetQuery(conn, sample_query)
-
- for (col in columns) {
- col_class <- class(sample_data[[col]])[1]
-
- if (
- col_class %in%
- c("integer", "numeric", "double", "Date", "POSIXct", "POSIXt")
- ) {
- numeric_columns <- c(numeric_columns, col)
- select_parts <- c(
- select_parts,
- paste0(
- "MIN(",
- DBI::dbQuoteIdentifier(conn, col),
- ") as ",
- DBI::dbQuoteIdentifier(conn, paste0(col, '__min'))
- ),
- paste0(
- "MAX(",
- DBI::dbQuoteIdentifier(conn, col),
- ") as ",
- DBI::dbQuoteIdentifier(conn, paste0(col, '__max'))
- )
- )
- } else if (col_class %in% c("character", "factor")) {
- text_columns <- c(text_columns, col)
- select_parts <- c(
- select_parts,
- paste0(
- "COUNT(DISTINCT ",
- DBI::dbQuoteIdentifier(conn, col),
- ") as ",
- DBI::dbQuoteIdentifier(conn, paste0(col, '__distinct_count'))
- )
- )
- }
- }
-
- # Execute statistics query
- column_stats <- list()
- if (length(select_parts) > 0) {
- tryCatch(
- {
- stats_query <- paste0(
- "SELECT ",
- paste0(select_parts, collapse = ", "),
- " FROM ",
- DBI::dbQuoteIdentifier(conn, table_name)
- )
- result <- DBI::dbGetQuery(conn, stats_query)
- if (nrow(result) > 0) {
- column_stats <- as.list(result[1, ])
- }
- },
- error = function(e) {
- # Fall back to no statistics if query fails
- }
- )
- }
-
- # Get categorical values for text columns below threshold
- categorical_values <- list()
- text_cols_to_query <- character(0)
-
- for (col_name in text_columns) {
- distinct_count_key <- paste0(col_name, "__distinct_count")
- if (
- distinct_count_key %in%
- names(column_stats) &&
- !is.na(column_stats[[distinct_count_key]]) &&
- column_stats[[distinct_count_key]] <= categorical_threshold
- ) {
- text_cols_to_query <- c(text_cols_to_query, col_name)
- }
- }
-
- # Remove duplicates
- text_cols_to_query <- unique(text_cols_to_query)
-
- # Get categorical values
- if (length(text_cols_to_query) > 0) {
- for (col_name in text_cols_to_query) {
- tryCatch(
- {
- cat_query <- paste0(
- "SELECT DISTINCT ",
- DBI::dbQuoteIdentifier(conn, col_name),
- " FROM ",
- DBI::dbQuoteIdentifier(conn, table_name),
- " WHERE ",
- DBI::dbQuoteIdentifier(conn, col_name),
- " IS NOT NULL ORDER BY ",
- DBI::dbQuoteIdentifier(conn, col_name)
- )
- result <- DBI::dbGetQuery(conn, cat_query)
- if (nrow(result) > 0) {
- categorical_values[[col_name]] <- result[[1]]
- }
- },
- error = function(e) {
- # Skip categorical values if query fails
- }
- )
- }
- }
-
- # Build schema description
- for (col in columns) {
- col_class <- class(sample_data[[col]])[1]
- sql_type <- r_class_to_sql_type(col_class)
-
- column_info <- paste0("- ", col, " (", sql_type, ")")
-
- # Add range info for numeric columns
- if (col %in% numeric_columns) {
- min_key <- paste0(col, "__min")
- max_key <- paste0(col, "__max")
- if (
- min_key %in%
- names(column_stats) &&
- max_key %in% names(column_stats) &&
- !is.na(column_stats[[min_key]]) &&
- !is.na(column_stats[[max_key]])
- ) {
- range_info <- paste0(
- " Range: ",
- column_stats[[min_key]],
- " to ",
- column_stats[[max_key]]
- )
- column_info <- paste(column_info, range_info, sep = "\n")
- }
- }
-
- # Add categorical values for text columns
- if (col %in% names(categorical_values)) {
- values <- categorical_values[[col]]
- if (length(values) > 0) {
- values_str <- paste0("'", values, "'", collapse = ", ")
- cat_info <- paste0(" Categorical values: ", values_str)
- column_info <- paste(column_info, cat_info, sep = "\n")
- }
- }
-
- schema_lines <- c(schema_lines, column_info)
- }
-
- paste(schema_lines, collapse = "\n")
-}
-
-
-# nocov start
-# Map R classes to SQL types
-r_class_to_sql_type <- function(r_class) {
- switch(
- r_class,
- "integer" = "INTEGER",
- "numeric" = "FLOAT",
- "double" = "FLOAT",
- "logical" = "BOOLEAN",
- "Date" = "DATE",
- "POSIXct" = "TIMESTAMP",
- "POSIXt" = "TIMESTAMP",
- "character" = "TEXT",
- "factor" = "TEXT",
- "TEXT" # default
- )
-}
-# nocov end
diff --git a/pkg-r/R/QueryChat.R b/pkg-r/R/QueryChat.R
index cf8c7483..149d5deb 100644
--- a/pkg-r/R/QueryChat.R
+++ b/pkg-r/R/QueryChat.R
@@ -142,9 +142,9 @@ QueryChat <- R6::R6Class(
#'
#' # With database
#' library(DBI)
- #' conn <- dbConnect(RSQLite::SQLite(), ":memory:")
- #' dbWriteTable(conn, "mtcars", mtcars)
- #' qc <- QueryChat$new(conn, "mtcars")
+ #' con <- dbConnect(RSQLite::SQLite(), ":memory:")
+ #' dbWriteTable(con, "mtcars", mtcars)
+ #' qc <- QueryChat$new(con, "mtcars")
#' }
initialize = function(
data_source,
@@ -172,8 +172,10 @@ QueryChat <- R6::R6Class(
check_string(prompt_template, allow_null = TRUE)
check_bool(cleanup, allow_na = TRUE)
- if (is_missing(table_name) && is.data.frame(data_source)) {
- table_name <- deparse1(substitute(data_source))
+ if (is_missing(table_name)) {
+ if (is.data.frame(data_source) || inherits(data_source, "tbl_sql")) {
+ table_name <- deparse1(substitute(data_source))
+ }
}
private$.data_source <- normalize_data_source(data_source, table_name)
@@ -427,8 +429,14 @@ QueryChat <- R6::R6Class(
})
output$dt <- DT::renderDT({
+ df <- qc_vals$df()
+ if (inherits(df, "tbl_sql")) {
+ # Materialize the lazy query for DT; {dplyr} is guaranteed by TblSqlSource
+ df <- dplyr::collect(df)
+ }
+
DT::datatable(
- qc_vals$df(),
+ df,
fillContainer = TRUE,
options = list(pageLength = 25, scrollX = TRUE)
)
@@ -740,9 +748,9 @@ QueryChat <- R6::R6Class(
#'
#' # Chat with a database table (table_name required)
#' library(DBI)
-#' conn <- dbConnect(RSQLite::SQLite(), ":memory:")
-#' dbWriteTable(conn, "mtcars", mtcars)
-#' querychat_app(conn, "mtcars")
+#' con <- dbConnect(RSQLite::SQLite(), ":memory:")
+#' dbWriteTable(con, "mtcars", mtcars)
+#' querychat_app(con, "mtcars")
#'
#' # Create QueryChat class object
#' qc <- querychat(mtcars)
@@ -765,8 +773,10 @@ querychat <- function(
prompt_template = NULL,
cleanup = NA
) {
- if (is_missing(table_name) && is.data.frame(data_source)) {
- table_name <- deparse1(substitute(data_source))
+ if (is_missing(table_name)) {
+ if (is.data.frame(data_source) || inherits(data_source, "tbl_sql")) {
+ table_name <- deparse1(substitute(data_source))
+ }
}
QueryChat$new(
@@ -853,6 +863,10 @@ normalize_data_source <- function(data_source, table_name) {
return(DataFrameSource$new(data_source, table_name))
}
+ if (inherits(data_source, "tbl_sql")) {
+ return(TblSqlSource$new(data_source, table_name))
+ }
+
if (inherits(data_source, "DBIConnection")) {
return(DBISource$new(data_source, table_name))
}
diff --git a/pkg-r/R/QueryChatSystemPrompt.R b/pkg-r/R/QueryChatSystemPrompt.R
index 4f40c472..52e99fb5 100644
--- a/pkg-r/R/QueryChatSystemPrompt.R
+++ b/pkg-r/R/QueryChatSystemPrompt.R
@@ -88,7 +88,8 @@ QueryChatSystemPrompt <- R6::R6Class(
data_description = self$data_description,
extra_instructions = self$extra_instructions,
has_tool_update = if ("update" %in% tools) "true",
- has_tool_query = if ("query" %in% tools) "true"
+ has_tool_query = if ("query" %in% tools) "true",
+ include_query_guidelines = if (length(tools) > 0) "true"
)
whisker::whisker.render(self$template, context)
diff --git a/pkg-r/R/TblSqlSource.R b/pkg-r/R/TblSqlSource.R
new file mode 100644
index 00000000..3b6210af
--- /dev/null
+++ b/pkg-r/R/TblSqlSource.R
@@ -0,0 +1,174 @@
+#' Data Source: SQL Tibble
+#'
+#' @description
+#' A DataSource implementation for lazy SQL tibbles connected to databases via
+#' [dbplyr::tbl_sql()] or [dplyr::sql()].
+#'
+#' @examplesIf rlang::is_interactive() && rlang::is_installed("dbplyr") && rlang::is_installed("dplyr") && rlang::is_installed("duckdb")
+#' con <- DBI::dbConnect(duckdb::duckdb())
+#' DBI::dbWriteTable(con, "mtcars", mtcars)
+#'
+#' mtcars_source <- TblSqlSource$new(dplyr::tbl(con, "mtcars"))
+#' mtcars_source$get_db_type() # "DuckDB"
+#'
+#' result <- mtcars_source$execute_query("SELECT * FROM mtcars WHERE cyl > 4")
+#'
+#' # Note, the result is not the *full* data frame, but a lazy SQL tibble
+#' result
+#'
+#' # You can chain this result into a dplyr pipeline
+#' dplyr::count(result, cyl, gear)
+#'
+#' # Or collect the entire data frame into local memory
+#' dplyr::collect(result)
+#'
+#' # Finally, clean up when done with the database (closes the DB connection)
+#' mtcars_source$cleanup()
+#'
+#' @export
+TblSqlSource <- R6::R6Class(
+ "TblSqlSource",
+ inherit = DBISource,
+ private = list(
+ tbl = NULL,
+ tbl_cte = NULL
+ ),
+ public = list(
+ #' @field table_name Name of the table to be used in SQL queries
+ table_name = NULL,
+
+ #' @description
+ #' Create a new TblSqlSource
+ #'
+ #' @param tbl A [dbplyr::tbl_sql()] (or SQL tibble via [dplyr::tbl()]).
+ #' @param table_name Name of the table in the database. Can be a character
+ #' string, or will be inferred from the `tbl` argument, if possible.
+ #' @return A new TblSqlSource object
+ #' @examplesIf rlang::is_interactive() && rlang::is_installed("dbplyr") && rlang::is_installed("dplyr") && rlang::is_installed("RSQLite")
+ #' con <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
+ #' DBI::dbWriteTable(con, "mtcars", mtcars)
+ #' source <- TblSqlSource$new(dplyr::tbl(con, "mtcars"))
+ initialize = function(tbl, table_name = missing_arg()) {
+ check_installed("dbplyr")
+ check_installed("dplyr")
+
+ if (!inherits(tbl, "tbl_sql")) {
+ cli::cli_abort(
+ "{.arg tbl} must be a SQL tibble connected to a database, not {.obj_type_friendly {tbl}}"
+ )
+ }
+
+ private$conn <- dbplyr::remote_con(tbl)
+ private$tbl <- tbl
+
+ # Collect various signals to infer the table name
+ obj_name <- deparse1(substitute(tbl))
+
+ # Get the exact table name, if tbl directly references a single table
+ remote_name <- dbplyr::remote_name(private$tbl)
+
+ use_cte <- FALSE
+
+ if (!is_missing(table_name)) {
+ check_sql_table_name(table_name)
+ self$table_name <- table_name
+ use_cte <- !identical(table_name, remote_name)
+ } else if (!is.null(remote_name)) {
+ # Remote name is non-NULL when it points to a table, so we use that next
+ self$table_name <- remote_name
+ use_cte <- FALSE
+ } else if (is_valid_sql_table_name(obj_name)) {
+ self$table_name <- obj_name
+ use_cte <- TRUE
+ } else {
+ id <- as.integer(runif(1) * 1e6)
+ self$table_name <- sprintf("querychat_cte_%d", id)
+ use_cte <- TRUE
+ }
+
+ if (use_cte) {
+ # We received a complicated tbl expression, so we'll have to use a CTE
+ private$tbl_cte <- dbplyr::remote_query(private$tbl)
+ }
+ },
+
+ #' @description
+ #' Get the database type
+ #'
+ #' @return A string describing the database type (e.g., "DuckDB", "SQLite")
+ get_db_type = function() {
+ super$get_db_type()
+ },
+
+ #' @description
+ #' Get schema information about the table
+ #'
+ #' @param categorical_threshold Maximum number of unique values for a text
+ #' column to be considered categorical
+ #' @return A string containing schema information formatted for LLM prompts
+ get_schema = function(categorical_threshold = 20) {
+ get_schema_impl(
+ private$conn,
+ self$table_name,
+ categorical_threshold,
+ columns = colnames(private$tbl),
+ prep_query = self$prep_query
+ )
+ },
+
+ #' @description
+ #' Execute a SQL query and return results
+ #'
+ #' @param query SQL query string to execute
+ #' @return A data frame containing query results
+ execute_query = function(query) {
+ sql_query <- self$prep_query(query)
+ dplyr::tbl(private$conn, dplyr::sql(sql_query))
+ },
+
+ #' @description
+ #' Test a SQL query by fetching only one row
+ #'
+ #' @param query SQL query string to test
+ #' @return A data frame containing one row of results (or empty if no matches)
+ test_query = function(query) {
+ super$test_query(self$prep_query(query))
+ },
+
+ #' @description
+ #' Prepare a generic `SELECT * FROM ____` query to work with the SQL tibble
+ #'
+ #' @param query SQL query as a string
+ #' @return A complete SQL query string
+ prep_query = function(query) {
+ check_string(query)
+
+ if (is.null(private$tbl_cte)) {
+ return(query)
+ }
+
+ sprintf(
+ "WITH %s AS (\n%s\n)\n%s",
+ DBI::dbQuoteIdentifier(private$conn, self$table_name),
+ private$tbl_cte,
+ query
+ )
+ },
+
+ #' @description
+ #' Get the unfiltered data as a SQL tibble
+ #'
+ #' @return A [dbplyr::tbl_sql()] containing the original, unfiltered data
+ get_data = function() {
+ private$tbl
+ },
+
+ #' @description
+ #' Clean up resources (close connections, etc.)
+ #'
+ #' @return NULL (invisibly)
+ cleanup = function() {
+ super$cleanup()
+ }
+ )
+)
diff --git a/pkg-r/R/utils-check.R b/pkg-r/R/utils-check.R
index d6fb810c..85e8c25d 100644
--- a/pkg-r/R/utils-check.R
+++ b/pkg-r/R/utils-check.R
@@ -47,7 +47,7 @@ check_sql_table_name <- function(
check_string(x, allow_null = allow_null, arg = arg, call = call)
# Then validate SQL table name pattern
- if (!grepl("^[a-zA-Z][a-zA-Z0-9_]*$", x)) {
+ if (!is_valid_sql_table_name(x)) {
cli::cli_abort(
c(
"{.arg {arg}} must be a valid SQL table name",
@@ -61,6 +61,10 @@ check_sql_table_name <- function(
invisible(NULL)
}
+is_valid_sql_table_name <- function(x) {
+ grepl("^[a-zA-Z][a-zA-Z0-9_]*$", x)
+}
+
# SQL query validation --------------------------------------------------------
diff --git a/pkg-r/inst/prompts/prompt.md b/pkg-r/inst/prompts/prompt.md
index 6a92101c..7c8ea5a1 100644
--- a/pkg-r/inst/prompts/prompt.md
+++ b/pkg-r/inst/prompts/prompt.md
@@ -16,6 +16,44 @@ Here is additional information about the data:
For security reasons, you may only query this specific table.
+{{#include_query_guidelines}}
+## SQL Query Guidelines
+
+When writing SQL queries to interact with the database, please adhere to the following guidelines to ensure compatibility and correctness.
+
+### Structural Rules
+
+**No trailing semicolons**
+Never end your query with a semicolon (`;`). The parent query needs to continue after your subquery closes.
+
+**Single statement only**
+Return exactly one `SELECT` statement. Do not include multiple statements separated by semicolons.
+
+**No procedural or meta statements**
+Do not include:
+- `EXPLAIN` / `EXPLAIN ANALYZE`
+- `SET` statements
+- Variable declarations
+- Transaction controls (`BEGIN`, `COMMIT`, `ROLLBACK`)
+- DDL statements (`CREATE`, `ALTER`, `DROP`)
+- `INTO` clauses (e.g., `SELECT INTO`)
+- Locking hints (`FOR UPDATE`, `FOR SHARE`)
+
+### Column Naming Rules
+
+**Alias all computed/derived columns**
+Every expression that isn't a simple column reference must have an explicit alias.
+
+**Ensure unique column names**
+The result set must not have duplicate column names, even when selecting from multiple tables.
+
+**Avoid `SELECT *` with JOINs**
+Explicitly list columns to prevent duplicate column names and ensure a predictable output schema.
+
+**Avoid reserved words as unquoted aliases**
+If using reserved words as column aliases, quote them appropriately for your dialect.
+
+{{/include_query_guidelines}}
{{#is_duck_db}}
### DuckDB SQL Tips
@@ -130,7 +168,7 @@ You might want to explore the advanced features
- The user has asked a very specific question requiring only a direct answer
- The conversation is clearly wrapping up
-#### Guidelines
+#### Suggestion Guidelines
- Suggestions can appear **anywhere** in your response—not just at the end
- Use list format at the end for 2-4 follow-up options (most common pattern)
@@ -141,7 +179,6 @@ You might want to explore the advanced features
- Never use generic phrases like "If you'd like to..." or "Would you like to explore..." — instead, provide concrete suggestions
- Never refer to suggestions as "prompts" – call them "suggestions" or "ideas" or similar
-
## Important Guidelines
- **Ask for clarification** if any request is unclear or ambiguous
diff --git a/pkg-r/man/DBISource.Rd b/pkg-r/man/DBISource.Rd
index 202e71b0..ae6b05aa 100644
--- a/pkg-r/man/DBISource.Rd
+++ b/pkg-r/man/DBISource.Rd
@@ -1,24 +1,21 @@
% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/DataSource.R
+% Please edit documentation in R/DBISource.R
\name{DBISource}
\alias{DBISource}
\title{DBI Source}
\description{
A DataSource implementation for DBI database connections (SQLite, PostgreSQL,
-MySQL, etc.).
-}
-\details{
-This class wraps a DBI connection and provides SQL query execution against
-a specified table in the database.
+MySQL, etc.). This class wraps a DBI connection and provides SQL query
+execution against a single table in the database.
}
\examples{
\dontrun{
# Connect to a database
-conn <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
-DBI::dbWriteTable(conn, "mtcars", mtcars)
+con <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
+DBI::dbWriteTable(con, "mtcars", mtcars)
# Create a DBI source
-db_source <- DBISource$new(conn, "mtcars")
+db_source <- DBISource$new(con, "mtcars")
# Get database type
db_source$get_db_type() # Returns "SQLite"
@@ -35,9 +32,9 @@ result <- db_source$execute_query("SELECT * FROM mtcars WHERE mpg > 25")
## ------------------------------------------------
\dontrun{
-conn <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
-DBI::dbWriteTable(conn, "iris", iris)
-source <- DBISource$new(conn, "iris")
+con <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
+DBI::dbWriteTable(con, "iris", iris)
+source <- DBISource$new(con, "iris")
}
}
\section{Super class}{
@@ -81,9 +78,9 @@ A new DBISource object
\subsection{Examples}{
\if{html}{\out{<div class="r example copy">}}
\preformatted{\dontrun{
-conn <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
-DBI::dbWriteTable(conn, "iris", iris)
-source <- DBISource$new(conn, "iris")
+con <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
+DBI::dbWriteTable(con, "iris", iris)
+source <- DBISource$new(con, "iris")
}
}
\if{html}{\out{</div>}}
diff --git a/pkg-r/man/DataFrameSource.Rd b/pkg-r/man/DataFrameSource.Rd
index b33d1e66..cefd2fa4 100644
--- a/pkg-r/man/DataFrameSource.Rd
+++ b/pkg-r/man/DataFrameSource.Rd
@@ -1,5 +1,5 @@
% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/DataSource.R
+% Please edit documentation in R/DataFrameSource.R
\name{DataFrameSource}
\alias{DataFrameSource}
\title{Data Frame Source}
diff --git a/pkg-r/man/QueryChat.Rd b/pkg-r/man/QueryChat.Rd
index 64f15b0a..eb873b7e 100644
--- a/pkg-r/man/QueryChat.Rd
+++ b/pkg-r/man/QueryChat.Rd
@@ -86,9 +86,9 @@ qc <- QueryChat$new(
# With database
library(DBI)
-conn <- dbConnect(RSQLite::SQLite(), ":memory:")
-dbWriteTable(conn, "mtcars", mtcars)
-qc <- QueryChat$new(conn, "mtcars")
+con <- dbConnect(RSQLite::SQLite(), ":memory:")
+dbWriteTable(con, "mtcars", mtcars)
+qc <- QueryChat$new(con, "mtcars")
}
## ------------------------------------------------
@@ -311,9 +311,9 @@ qc <- QueryChat$new(
# With database
library(DBI)
-conn <- dbConnect(RSQLite::SQLite(), ":memory:")
-dbWriteTable(conn, "mtcars", mtcars)
-qc <- QueryChat$new(conn, "mtcars")
+con <- dbConnect(RSQLite::SQLite(), ":memory:")
+dbWriteTable(con, "mtcars", mtcars)
+qc <- QueryChat$new(con, "mtcars")
}
}
\if{html}{\out{</div>}}
diff --git a/pkg-r/man/TblSqlSource.Rd b/pkg-r/man/TblSqlSource.Rd
new file mode 100644
index 00000000..83e2b84d
--- /dev/null
+++ b/pkg-r/man/TblSqlSource.Rd
@@ -0,0 +1,222 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/TblSqlSource.R
+\name{TblSqlSource}
+\alias{TblSqlSource}
+\title{Data Source: SQL Tibble}
+\description{
+A DataSource implementation for lazy SQL tibbles connected to databases via
+\code{\link[dbplyr:tbl_sql]{dbplyr::tbl_sql()}} or \code{\link[dplyr:sql]{dplyr::sql()}}.
+}
+\examples{
+\dontshow{if (rlang::is_interactive() && rlang::is_installed("dbplyr") && rlang::is_installed("dplyr") && rlang::is_installed("duckdb")) withAutoprint(\{ # examplesIf}
+con <- DBI::dbConnect(duckdb::duckdb())
+DBI::dbWriteTable(con, "mtcars", mtcars)
+
+mtcars_source <- TblSqlSource$new(dplyr::tbl(con, "mtcars"))
+mtcars_source$get_db_type() # "DuckDB"
+
+result <- mtcars_source$execute_query("SELECT * FROM mtcars WHERE cyl > 4")
+
+# Note, the result is not the *full* data frame, but a lazy SQL tibble
+result
+
+# You can chain this result into a dplyr pipeline
+dplyr::count(result, cyl, gear)
+
+# Or collect the entire data frame into local memory
+dplyr::collect(result)
+
+# Finally, clean up when done with the database (closes the DB connection)
+mtcars_source$cleanup()
+\dontshow{\}) # examplesIf}
+\dontshow{if (rlang::is_interactive() && rlang::is_installed("dbplyr") && rlang::is_installed("dplyr") && rlang::is_installed("RSQLite")) withAutoprint(\{ # examplesIf}
+con <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
+DBI::dbWriteTable(con, "mtcars", mtcars)
+source <- TblSqlSource$new(dplyr::tbl(con, "mtcars"))
+\dontshow{\}) # examplesIf}
+}
+\section{Super classes}{
+\code{\link[querychat:DataSource]{querychat::DataSource}} -> \code{\link[querychat:DBISource]{querychat::DBISource}} -> \code{TblSqlSource}
+}
+\section{Public fields}{
+\if{html}{\out{<div class="r6-fields">}}
+\describe{
+\item{\code{table_name}}{Name of the table to be used in SQL queries}
+}
+\if{html}{\out{</div>}}
+}
+\section{Methods}{
+\subsection{Public methods}{
+\itemize{
+\item \href{#method-TblSqlSource-new}{\code{TblSqlSource$new()}}
+\item \href{#method-TblSqlSource-get_db_type}{\code{TblSqlSource$get_db_type()}}
+\item \href{#method-TblSqlSource-get_schema}{\code{TblSqlSource$get_schema()}}
+\item \href{#method-TblSqlSource-execute_query}{\code{TblSqlSource$execute_query()}}
+\item \href{#method-TblSqlSource-test_query}{\code{TblSqlSource$test_query()}}
+\item \href{#method-TblSqlSource-prep_query}{\code{TblSqlSource$prep_query()}}
+\item \href{#method-TblSqlSource-get_data}{\code{TblSqlSource$get_data()}}
+\item \href{#method-TblSqlSource-cleanup}{\code{TblSqlSource$cleanup()}}
+\item \href{#method-TblSqlSource-clone}{\code{TblSqlSource$clone()}}
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-TblSqlSource-new"></a>}}
+\if{latex}{\out{\hypertarget{method-TblSqlSource-new}{}}}
+\subsection{Method \code{new()}}{
+Create a new TblSqlSource
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{TblSqlSource$new(tbl, table_name = missing_arg())}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{tbl}}{A \code{\link[dbplyr:tbl_sql]{dbplyr::tbl_sql()}} (or SQL tibble via \code{\link[dplyr:tbl]{dplyr::tbl()}}).}
+
+\item{\code{table_name}}{Name of the table in the database. Can be a character
+string, or will be inferred from the \code{tbl} argument, if possible.}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+A new TblSqlSource object
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-TblSqlSource-get_db_type"></a>}}
+\if{latex}{\out{\hypertarget{method-TblSqlSource-get_db_type}{}}}
+\subsection{Method \code{get_db_type()}}{
+Get the database type
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{TblSqlSource$get_db_type()}\if{html}{\out{</div>}}
+}
+
+\subsection{Returns}{
+A string describing the database type (e.g., "DuckDB", "SQLite")
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-TblSqlSource-get_schema"></a>}}
+\if{latex}{\out{\hypertarget{method-TblSqlSource-get_schema}{}}}
+\subsection{Method \code{get_schema()}}{
+Get schema information about the table
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{TblSqlSource$get_schema(categorical_threshold = 20)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{categorical_threshold}}{Maximum number of unique values for a text
+column to be considered categorical}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+A string containing schema information formatted for LLM prompts
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-TblSqlSource-execute_query"></a>}}
+\if{latex}{\out{\hypertarget{method-TblSqlSource-execute_query}{}}}
+\subsection{Method \code{execute_query()}}{
+Execute a SQL query and return results
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{TblSqlSource$execute_query(query)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{query}}{SQL query string to execute}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+A data frame containing query results
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-TblSqlSource-test_query"></a>}}
+\if{latex}{\out{\hypertarget{method-TblSqlSource-test_query}{}}}
+\subsection{Method \code{test_query()}}{
+Test a SQL query by fetching only one row
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{TblSqlSource$test_query(query)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{query}}{SQL query string to test}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+A data frame containing one row of results (or empty if no matches)
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-TblSqlSource-prep_query"></a>}}
+\if{latex}{\out{\hypertarget{method-TblSqlSource-prep_query}{}}}
+\subsection{Method \code{prep_query()}}{
+Prepare a generic \verb{SELECT * FROM ____} query to work with the SQL tibble
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{TblSqlSource$prep_query(query)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{query}}{SQL query as a string}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+A complete SQL query string
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-TblSqlSource-get_data"></a>}}
+\if{latex}{\out{\hypertarget{method-TblSqlSource-get_data}{}}}
+\subsection{Method \code{get_data()}}{
+Get the unfiltered data as a SQL tibble
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{TblSqlSource$get_data()}\if{html}{\out{</div>}}
+}
+
+\subsection{Returns}{
+A \code{\link[dbplyr:tbl_sql]{dbplyr::tbl_sql()}} containing the original, unfiltered data
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-TblSqlSource-cleanup"></a>}}
+\if{latex}{\out{\hypertarget{method-TblSqlSource-cleanup}{}}}
+\subsection{Method \code{cleanup()}}{
+Clean up resources (close connections, etc.)
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{TblSqlSource$cleanup()}\if{html}{\out{</div>}}
+}
+
+\subsection{Returns}{
+NULL (invisibly)
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-TblSqlSource-clone"></a>}}
+\if{latex}{\out{\hypertarget{method-TblSqlSource-clone}{}}}
+\subsection{Method \code{clone()}}{
+The objects of this class are cloneable with this method.
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{TblSqlSource$clone(deep = FALSE)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{deep}}{Whether to make a deep clone.}
+}
+\if{html}{\out{</div>}}
+}
+}
+}
diff --git a/pkg-r/man/querychat-convenience.Rd b/pkg-r/man/querychat-convenience.Rd
index fd4410da..1be2eace 100644
--- a/pkg-r/man/querychat-convenience.Rd
+++ b/pkg-r/man/querychat-convenience.Rd
@@ -123,9 +123,9 @@ querychat_app(
# Chat with a database table (table_name required)
library(DBI)
-conn <- dbConnect(RSQLite::SQLite(), ":memory:")
-dbWriteTable(conn, "mtcars", mtcars)
-querychat_app(conn, "mtcars")
+con <- dbConnect(RSQLite::SQLite(), ":memory:")
+dbWriteTable(con, "mtcars", mtcars)
+querychat_app(con, "mtcars")
# Create QueryChat class object
qc <- querychat(mtcars)
diff --git a/pkg-r/pkgdown/_pkgdown.yml b/pkg-r/pkgdown/_pkgdown.yml
index 94cc6cc0..f4a6b7db 100644
--- a/pkg-r/pkgdown/_pkgdown.yml
+++ b/pkg-r/pkgdown/_pkgdown.yml
@@ -35,7 +35,7 @@ navbar:
right: [search, github, lightswitch]
components:
articles:
- text: Articles
+ text: Articles
menu:
- text: Models
href: articles/models.html
@@ -64,6 +64,7 @@ reference:
- DataSource
- DataFrameSource
- DBISource
+ - TblSqlSource
- title: Package
contents:
- querychat
diff --git a/pkg-r/tests/testthat/_snaps/DBISource.md b/pkg-r/tests/testthat/_snaps/DBISource.md
new file mode 100644
index 00000000..8559e656
--- /dev/null
+++ b/pkg-r/tests/testthat/_snaps/DBISource.md
@@ -0,0 +1,57 @@
+# DBISource$new() / errors with non-DBI connection
+
+ Code
+ DBISource$new(list(fake = "connection"), "test_table")
+ Condition
+ Error in `initialize()`:
! `conn` must be a <DBIConnection>, not a list
+
+---
+
+ Code
+ DBISource$new(NULL, "test_table")
+ Condition
+ Error in `initialize()`:
! `conn` must be a <DBIConnection>, not NULL
+
+---
+
+ Code
+ DBISource$new("not a connection", "test_table")
+ Condition
+ Error in `initialize()`:
! `conn` must be a <DBIConnection>, not a string
+
+# DBISource$new() / errors with invalid table_name types
+
+ Code
+ DBISource$new(db$conn, 123)
+ Condition
+ Error in `initialize()`:
+ ! `table_name` must be a single character string or a `DBI::Id()` object
+
+---
+
+ Code
+ DBISource$new(db$conn, c("table1", "table2"))
+ Condition
+ Error in `initialize()`:
+ ! `table_name` must be a single character string or a `DBI::Id()` object
+
+---
+
+ Code
+ DBISource$new(db$conn, list(name = "table"))
+ Condition
+ Error in `initialize()`:
+ ! `table_name` must be a single character string or a `DBI::Id()` object
+
+# DBISource$new() / errors when table does not exist
+
+ Code
+ DBISource$new(db$conn, "non_existent_table")
+ Condition
+ Error in `initialize()`:
+ ! Table "`non_existent_table`" not found in database
+ i If you're using a table in a catalog or schema, pass a `DBI::Id()` object to `table_name`
+
diff --git a/pkg-r/tests/testthat/_snaps/DataFrameSource.md b/pkg-r/tests/testthat/_snaps/DataFrameSource.md
new file mode 100644
index 00000000..3d21372d
--- /dev/null
+++ b/pkg-r/tests/testthat/_snaps/DataFrameSource.md
@@ -0,0 +1,84 @@
+# DataFrameSource$new() / errors with non-data.frame input
+
+ Code
+ DataFrameSource$new(list(a = 1, b = 2), "test_table")
+ Condition
+ Error in `initialize()`:
+ ! `df` must be a data frame, not a list.
+
+---
+
+ Code
+ DataFrameSource$new(c(1, 2, 3), "test_table")
+ Condition
+ Error in `initialize()`:
+ ! `df` must be a data frame, not a double vector.
+
+---
+
+ Code
+ DataFrameSource$new(NULL, "test_table")
+ Condition
+ Error in `initialize()`:
+ ! `df` must be a data frame, not `NULL`.
+
+# DataFrameSource$new() / errors with invalid table names
+
+ Code
+ DataFrameSource$new(test_df, "123_invalid")
+ Condition
+ Error in `initialize()`:
+ ! `table_name` must be a valid SQL table name
+ i Table names must begin with a letter and contain only letters, numbers, and underscores
+ x You provided: "123_invalid"
+ Code
+ DataFrameSource$new(test_df, "table-name")
+ Condition
+ Error in `initialize()`:
+ ! `table_name` must be a valid SQL table name
+ i Table names must begin with a letter and contain only letters, numbers, and underscores
+ x You provided: "table-name"
+ Code
+ DataFrameSource$new(test_df, "table name")
+ Condition
+ Error in `initialize()`:
+ ! `table_name` must be a valid SQL table name
+ i Table names must begin with a letter and contain only letters, numbers, and underscores
+ x You provided: "table name"
+ Code
+ DataFrameSource$new(test_df, "")
+ Condition
+ Error in `initialize()`:
+ ! `table_name` must be a valid SQL table name
+ i Table names must begin with a letter and contain only letters, numbers, and underscores
+ x You provided: ""
+ Code
+ DataFrameSource$new(test_df, NULL)
+ Condition
+ Error in `initialize()`:
+ ! `table_name` must be a single string, not `NULL`.
+
+# DataFrameSource engine parameter / engine parameter validation / errors on invalid engine name
+
+ Code
+ DataFrameSource$new(new_test_df(), "test_table", engine = "postgres")
+ Condition
+ Error in `initialize()`:
+ ! `engine` must be one of "duckdb" or "sqlite", not "postgres".
+
+---
+
+ Code
+ DataFrameSource$new(new_test_df(), "test_table", engine = "invalid")
+ Condition
+ Error in `initialize()`:
+ ! `engine` must be one of "duckdb" or "sqlite", not "invalid".
+
+---
+
+ Code
+ DataFrameSource$new(new_test_df(), "test_table", engine = "")
+ Condition
+ Error in `initialize()`:
+ ! `engine` must be one of "duckdb" or "sqlite", not "".
+
diff --git a/pkg-r/tests/testthat/_snaps/DataSource.md b/pkg-r/tests/testthat/_snaps/DataSource.md
index 957d7d80..8492aa4b 100644
--- a/pkg-r/tests/testthat/_snaps/DataSource.md
+++ b/pkg-r/tests/testthat/_snaps/DataSource.md
@@ -46,147 +46,6 @@
Error in `base_source$cleanup()`:
! `cleanup()` must be implemented by subclass
-# DataFrameSource$new() / errors with non-data.frame input
-
- Code
- DataFrameSource$new(list(a = 1, b = 2), "test_table")
- Condition
- Error in `initialize()`:
- ! `df` must be a data frame, not a list.
-
----
-
- Code
- DataFrameSource$new(c(1, 2, 3), "test_table")
- Condition
- Error in `initialize()`:
- ! `df` must be a data frame, not a double vector.
-
----
-
- Code
- DataFrameSource$new(NULL, "test_table")
- Condition
- Error in `initialize()`:
- ! `df` must be a data frame, not `NULL`.
-
-# DataFrameSource$new() / errors with invalid table names
-
- Code
- DataFrameSource$new(test_df, "123_invalid")
- Condition
- Error in `initialize()`:
- ! `table_name` must be a valid SQL table name
- i Table names must begin with a letter and contain only letters, numbers, and underscores
- x You provided: "123_invalid"
- Code
- DataFrameSource$new(test_df, "table-name")
- Condition
- Error in `initialize()`:
- ! `table_name` must be a valid SQL table name
- i Table names must begin with a letter and contain only letters, numbers, and underscores
- x You provided: "table-name"
- Code
- DataFrameSource$new(test_df, "table name")
- Condition
- Error in `initialize()`:
- ! `table_name` must be a valid SQL table name
- i Table names must begin with a letter and contain only letters, numbers, and underscores
- x You provided: "table name"
- Code
- DataFrameSource$new(test_df, "")
- Condition
- Error in `initialize()`:
- ! `table_name` must be a valid SQL table name
- i Table names must begin with a letter and contain only letters, numbers, and underscores
- x You provided: ""
- Code
- DataFrameSource$new(test_df, NULL)
- Condition
- Error in `initialize()`:
- ! `table_name` must be a single string, not `NULL`.
-
-# DataFrameSource engine parameter / engine parameter validation / errors on invalid engine name
-
- Code
- DataFrameSource$new(new_test_df(), "test_table", engine = "postgres")
- Condition
- Error in `initialize()`:
- ! `engine` must be one of "duckdb" or "sqlite", not "postgres".
-
----
-
- Code
- DataFrameSource$new(new_test_df(), "test_table", engine = "invalid")
- Condition
- Error in `initialize()`:
- ! `engine` must be one of "duckdb" or "sqlite", not "invalid".
-
----
-
- Code
- DataFrameSource$new(new_test_df(), "test_table", engine = "")
- Condition
- Error in `initialize()`:
- ! `engine` must be one of "duckdb" or "sqlite", not "".
-
-# DBISource$new() / errors with non-DBI connection
-
- Code
- DBISource$new(list(fake = "connection"), "test_table")
- Condition
- Error in `initialize()`:
- ! `conn` must be a , not a list
-
----
-
- Code
- DBISource$new(NULL, "test_table")
- Condition
- Error in `initialize()`:
- ! `conn` must be a , not NULL
-
----
-
- Code
- DBISource$new("not a connection", "test_table")
- Condition
- Error in `initialize()`:
- ! `conn` must be a , not a string
-
-# DBISource$new() / errors with invalid table_name types
-
- Code
- DBISource$new(db$conn, 123)
- Condition
- Error in `initialize()`:
- ! `table_name` must be a single character string or a `DBI::Id()` object
-
----
-
- Code
- DBISource$new(db$conn, c("table1", "table2"))
- Condition
- Error in `initialize()`:
- ! `table_name` must be a single character string or a `DBI::Id()` object
-
----
-
- Code
- DBISource$new(db$conn, list(name = "table"))
- Condition
- Error in `initialize()`:
- ! `table_name` must be a single character string or a `DBI::Id()` object
-
-# DBISource$new() / errors when table does not exist
-
- Code
- DBISource$new(db$conn, "non_existent_table")
- Condition
- Error in `initialize()`:
- ! Table "`non_existent_table`" not found in database
- i If you're using a table in a catalog or schema, pass a `DBI::Id()` object to `table_name`
-
# test_query() column validation / provides helpful error message listing missing columns
Code
diff --git a/pkg-r/tests/testthat/_snaps/TblSqlSource.md b/pkg-r/tests/testthat/_snaps/TblSqlSource.md
new file mode 100644
index 00000000..520e1a6a
--- /dev/null
+++ b/pkg-r/tests/testthat/_snaps/TblSqlSource.md
@@ -0,0 +1,16 @@
+# TblSqlSource$new() / errors with non-tbl_sql input
+
+ Code
+ TblSqlSource$new(data.frame(a = 1))
+ Condition
+ Error in `initialize()`:
+ ! `tbl` must be a SQL tibble connected to a database, not a data frame
+
+---
+
+ Code
+ TblSqlSource$new(list(a = 1, b = 2))
+ Condition
+ Error in `initialize()`:
+ ! `tbl` must be a SQL tibble connected to a database, not a list
+
diff --git a/pkg-r/tests/testthat/helper-fixtures.R b/pkg-r/tests/testthat/helper-fixtures.R
index c98e0643..d800b010 100644
--- a/pkg-r/tests/testthat/helper-fixtures.R
+++ b/pkg-r/tests/testthat/helper-fixtures.R
@@ -101,6 +101,61 @@ local_querychat <- function(
qc
}
+# Create a TblSqlSource with DuckDB and automatic cleanup
+local_tbl_sql_source <- function(
+ data = new_test_df(),
+ table_name = "test_table",
+ tbl_transform = identity,
+ env = parent.frame()
+) {
+ skip_if_not_installed("duckdb")
+ skip_if_not_installed("dbplyr")
+ skip_if_not_installed("dplyr")
+
+ conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:")
+ withr::defer(DBI::dbDisconnect(conn, shutdown = TRUE), envir = env)
+
+ DBI::dbWriteTable(conn, table_name, data, overwrite = TRUE)
+ tbl <- dplyr::tbl(conn, table_name)
+ tbl <- tbl_transform(tbl)
+
+ TblSqlSource$new(tbl, table_name)
+}
+
+# Create a DuckDB connection with multiple tables for JOIN tests
+local_duckdb_multi_table <- function(
+ env = parent.frame()
+) {
+ skip_if_not_installed("duckdb")
+
+ conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:")
+ withr::defer(DBI::dbDisconnect(conn, shutdown = TRUE), envir = env)
+
+ # Table A with id and name
+ DBI::dbWriteTable(
+ conn,
+ "table_a",
+ data.frame(
+ id = 1:3,
+ name = c("Alice", "Bob", "Carol"),
+ stringsAsFactors = FALSE
+ )
+ )
+
+ # Table B with id and value
+ DBI::dbWriteTable(
+ conn,
+ "table_b",
+ data.frame(
+ id = 1:3,
+ value = c(100, 200, 300),
+ stringsAsFactors = FALSE
+ )
+ )
+
+ conn
+}
+
mock_ellmer_chat_client <- function(
public = list(),
private = list(),
diff --git a/pkg-r/tests/testthat/test-DBISource.R b/pkg-r/tests/testthat/test-DBISource.R
new file mode 100644
index 00000000..69ff7be2
--- /dev/null
+++ b/pkg-r/tests/testthat/test-DBISource.R
@@ -0,0 +1,102 @@
+describe("DBISource$new()", {
+ it("creates proper R6 object for DBISource", {
+ db <- local_sqlite_connection(new_users_df(), "users")
+
+ db_source <- DBISource$new(db$conn, "users")
+ expect_s3_class(db_source, "DBISource")
+ expect_s3_class(db_source, "DataSource")
+ expect_equal(db_source$table_name, "users")
+ })
+
+ it("errors with non-DBI connection", {
+ expect_snapshot(error = TRUE, {
+ DBISource$new(list(fake = "connection"), "test_table")
+ })
+
+ expect_snapshot(error = TRUE, {
+ DBISource$new(NULL, "test_table")
+ })
+
+ expect_snapshot(error = TRUE, {
+ DBISource$new("not a connection", "test_table")
+ })
+ })
+
+ it("errors with invalid table_name types", {
+ db <- local_sqlite_connection(new_test_df())
+
+ expect_snapshot(error = TRUE, {
+ DBISource$new(db$conn, 123)
+ })
+
+ expect_snapshot(error = TRUE, {
+ DBISource$new(db$conn, c("table1", "table2"))
+ })
+
+ expect_snapshot(error = TRUE, {
+ DBISource$new(db$conn, list(name = "table"))
+ })
+ })
+
+ it("errors when table does not exist", {
+ db <- local_sqlite_connection(new_test_df(), "existing_table")
+
+ expect_snapshot(error = TRUE, {
+ DBISource$new(db$conn, "non_existent_table")
+ })
+ })
+})
+
+describe("DBISource$test_query()", {
+ test_df <- new_users_df()
+ db <- local_sqlite_connection(test_df, "test_table")
+ dbi_source <- DBISource$new(db$conn, "test_table")
+
+ it("correctly retrieves one row of data", {
+ result <- dbi_source$test_query("SELECT * FROM test_table")
+ expect_s3_class(result, "data.frame")
+ expect_equal(nrow(result), 1)
+ expect_equal(result$id, 1)
+
+ result <- dbi_source$test_query("SELECT * FROM test_table WHERE age > 29")
+ expect_s3_class(result, "data.frame")
+ expect_equal(nrow(result), 1)
+ expect_equal(result$age, 30)
+
+ result <- dbi_source$test_query(
+ "SELECT * FROM test_table ORDER BY age DESC"
+ )
+ expect_s3_class(result, "data.frame")
+ expect_equal(nrow(result), 1)
+ expect_equal(result$age, 35)
+
+ result <- dbi_source$test_query(
+ "SELECT * FROM test_table WHERE age > 100"
+ )
+ expect_s3_class(result, "data.frame")
+ expect_equal(nrow(result), 0)
+ })
+
+ it("handles errors correctly", {
+ expect_error(dbi_source$test_query("SELECT * WRONG SYNTAX"))
+
+ expect_error(dbi_source$test_query("SELECT * FROM non_existent_table"))
+
+ expect_error(dbi_source$test_query(
+ "SELECT non_existent_column FROM test_table"
+ ))
+ })
+
+ it("works with different data types", {
+ db <- local_sqlite_connection(new_types_df(), "types_table")
+ dbi_source <- DBISource$new(db$conn, "types_table")
+
+ result <- dbi_source$test_query("SELECT * FROM types_table")
+ expect_s3_class(result, "data.frame")
+ expect_equal(nrow(result), 1)
+ expect_type(result$text_col, "character")
+ expect_type(result$num_col, "double")
+ expect_type(result$int_col, "integer")
+ expect_type(result$bool_col, "integer")
+ })
+})
diff --git a/pkg-r/tests/testthat/test-DataFrameSource.R b/pkg-r/tests/testthat/test-DataFrameSource.R
new file mode 100644
index 00000000..c2cc776d
--- /dev/null
+++ b/pkg-r/tests/testthat/test-DataFrameSource.R
@@ -0,0 +1,221 @@
+describe("DataFrameSource$new()", {
+ skip_if_no_dataframe_engine()
+
+ it("creates proper R6 object for DataFrameSource", {
+ test_df <- new_test_df()
+
+ source <- DataFrameSource$new(test_df, "test_table")
+ withr::defer(source$cleanup())
+
+ expect_s3_class(source, "DataFrameSource")
+ expect_s3_class(source, "DataSource")
+ expect_equal(source$table_name, "test_table")
+ })
+
+ it("errors with non-data.frame input", {
+ expect_snapshot(
+ error = TRUE,
+ DataFrameSource$new(list(a = 1, b = 2), "test_table")
+ )
+ expect_snapshot(error = TRUE, DataFrameSource$new(c(1, 2, 3), "test_table"))
+ expect_snapshot(error = TRUE, DataFrameSource$new(NULL, "test_table"))
+ })
+
+ it("errors with invalid table names", {
+ test_df <- new_test_df()
+
+ expect_snapshot(error = TRUE, {
+ DataFrameSource$new(test_df, "123_invalid")
+ DataFrameSource$new(test_df, "table-name")
+ DataFrameSource$new(test_df, "table name")
+ DataFrameSource$new(test_df, "")
+ DataFrameSource$new(test_df, NULL)
+ })
+ })
+})
+
+describe("DataFrameSource engine parameter", {
+ describe("with duckdb engine", {
+ skip_if_not_installed("duckdb")
+
+ it("creates connection with duckdb backend", {
+ df_source <- local_data_frame_source(new_test_df(), engine = "duckdb")
+
+ expect_s3_class(df_source, "DataFrameSource")
+ expect_s3_class(df_source, "DBISource")
+ expect_equal(df_source$table_name, "test_table")
+ expect_equal(df_source$get_db_type(), "DuckDB")
+ })
+
+ it("executes queries correctly", {
+ test_df <- new_test_df()
+ df_source <- local_data_frame_source(test_df, engine = "duckdb")
+
+ # Test filtering
+ result <- df_source$execute_query(
+ "SELECT * FROM test_table WHERE value > 25"
+ )
+ expect_equal(nrow(result), 3)
+ expect_equal(result$value, c(30, 40, 50))
+
+ # Test get_data
+ all_data <- df_source$get_data()
+ expect_equal(all_data, test_df)
+
+ # Test test_query
+ one_row <- df_source$test_query("SELECT * FROM test_table")
+ expect_equal(nrow(one_row), 1)
+ })
+
+ it("returns correct schema", {
+ df_source <- local_data_frame_source(
+ new_mixed_types_df(),
+ engine = "duckdb"
+ )
+ schema <- df_source$get_schema()
+
+ expect_type(schema, "character")
+ expect_match(schema, "Table: test_table")
+ expect_match(schema, "id \\(INTEGER\\)")
+ expect_match(schema, "name \\(TEXT\\)")
+ expect_match(schema, "active \\(BOOLEAN\\)")
+ })
+ })
+
+ describe("with sqlite engine", {
+ skip_if_not_installed("RSQLite")
+
+ it("creates connection with sqlite backend", {
+ df_source <- local_data_frame_source(new_test_df(), engine = "sqlite")
+
+ expect_s3_class(df_source, "DataFrameSource")
+ expect_s3_class(df_source, "DBISource")
+ expect_equal(df_source$table_name, "test_table")
+ expect_equal(df_source$get_db_type(), "SQLite")
+ })
+
+ it("executes queries correctly", {
+ test_df <- new_test_df()
+ df_source <- local_data_frame_source(test_df, engine = "sqlite")
+
+ # Test filtering
+ result <- df_source$execute_query(
+ "SELECT * FROM test_table WHERE value > 25"
+ )
+ expect_equal(nrow(result), 3)
+ expect_equal(result$value, c(30, 40, 50))
+
+ # Test get_data
+ all_data <- df_source$get_data()
+ expect_equal(all_data, test_df)
+
+ # Test test_query
+ one_row <- df_source$test_query("SELECT * FROM test_table")
+ expect_equal(nrow(one_row), 1)
+ })
+
+ it("returns correct schema", {
+ df_source <- local_data_frame_source(
+ new_mixed_types_df(),
+ engine = "sqlite"
+ )
+ schema <- df_source$get_schema()
+
+ expect_type(schema, "character")
+ expect_match(schema, "Table:")
+ expect_match(schema, "test_table")
+ expect_match(schema, "id \\(INTEGER\\)")
+ expect_match(schema, "name \\(TEXT\\)")
+ # SQLite stores booleans as INTEGER (0/1)
+ expect_match(schema, "active \\(INTEGER\\)")
+ })
+ })
+
+ describe("engine parameter validation", {
+ it("is case-insensitive", {
+ skip_if_not_installed("duckdb")
+ skip_if_not_installed("RSQLite")
+
+ # Test various case combinations
+ df1 <- local_data_frame_source(new_test_df(), engine = "DUCKDB")
+ expect_equal(df1$get_db_type(), "DuckDB")
+
+ df2 <- local_data_frame_source(new_test_df(), engine = "DuckDb")
+ expect_equal(df2$get_db_type(), "DuckDB")
+
+ df3 <- local_data_frame_source(new_test_df(), engine = "SQLite")
+ expect_equal(df3$get_db_type(), "SQLite")
+
+ df4 <- local_data_frame_source(new_test_df(), engine = "SQLITE")
+ expect_equal(df4$get_db_type(), "SQLite")
+ })
+
+ it("errors on invalid engine name", {
+ expect_snapshot(error = TRUE, {
+ DataFrameSource$new(new_test_df(), "test_table", engine = "postgres")
+ })
+
+ expect_snapshot(error = TRUE, {
+ DataFrameSource$new(new_test_df(), "test_table", engine = "invalid")
+ })
+
+ expect_snapshot(error = TRUE, {
+ DataFrameSource$new(new_test_df(), "test_table", engine = "")
+ })
+ })
+
+ it("respects getOption('querychat.DataFrameSource.engine')", {
+ skip_if_not_installed("duckdb")
+ skip_if_not_installed("RSQLite")
+
+ # Test default (duckdb)
+ withr::local_options(querychat.DataFrameSource.engine = NULL)
+ df1 <- DataFrameSource$new(new_test_df(), "test_table")
+ withr::defer(df1$cleanup())
+ expect_equal(df1$get_db_type(), "DuckDB")
+
+ # Test option set to sqlite
+ withr::local_options(querychat.DataFrameSource.engine = "sqlite")
+ df2 <- DataFrameSource$new(new_test_df(), "test_table")
+ withr::defer(df2$cleanup())
+ expect_equal(df2$get_db_type(), "SQLite")
+
+ # Test explicit parameter overrides option
+ withr::local_options(querychat.DataFrameSource.engine = "sqlite")
+ df3 <- local_data_frame_source(new_test_df(), engine = "duckdb")
+ expect_equal(df3$get_db_type(), "DuckDB")
+ })
+ })
+})
+
+describe("DataFrameSource$test_query()", {
+ skip_if_no_dataframe_engine()
+
+ test_df <- new_users_df()
+ df_source <- local_data_frame_source(test_df, "test_table")
+
+ it("correctly retrieves one row of data", {
+ result <- df_source$test_query("SELECT * FROM test_table")
+ expect_s3_class(result, "data.frame")
+ expect_equal(nrow(result), 1)
+ expect_equal(result$id, 1)
+
+ result <- df_source$test_query("SELECT * FROM test_table WHERE age > 29")
+ expect_s3_class(result, "data.frame")
+ expect_equal(nrow(result), 1)
+ expect_equal(result$age, 30)
+
+ result <- df_source$test_query(
+ "SELECT * FROM test_table ORDER BY age DESC"
+ )
+ expect_s3_class(result, "data.frame")
+ expect_equal(nrow(result), 1)
+ expect_equal(result$age, 35)
+
+ result <- df_source$test_query(
+ "SELECT * FROM test_table WHERE age > 100"
+ )
+ expect_s3_class(result, "data.frame")
+ expect_equal(nrow(result), 0)
+ })
+})
diff --git a/pkg-r/tests/testthat/test-DataSource.R b/pkg-r/tests/testthat/test-DataSource.R
index 9a2333d2..267743e5 100644
--- a/pkg-r/tests/testthat/test-DataSource.R
+++ b/pkg-r/tests/testthat/test-DataSource.R
@@ -29,244 +29,6 @@ describe("DataSource base class", {
})
})
-describe("DataFrameSource$new()", {
- skip_if_no_dataframe_engine()
-
- it("creates proper R6 object for DataFrameSource", {
- test_df <- new_test_df()
-
- source <- DataFrameSource$new(test_df, "test_table")
- withr::defer(source$cleanup())
-
- expect_s3_class(source, "DataFrameSource")
- expect_s3_class(source, "DataSource")
- expect_equal(source$table_name, "test_table")
- })
-
- it("errors with non-data.frame input", {
- expect_snapshot(
- error = TRUE,
- DataFrameSource$new(list(a = 1, b = 2), "test_table")
- )
- expect_snapshot(error = TRUE, DataFrameSource$new(c(1, 2, 3), "test_table"))
- expect_snapshot(error = TRUE, DataFrameSource$new(NULL, "test_table"))
- })
-
- it("errors with invalid table names", {
- test_df <- new_test_df()
-
- expect_snapshot(error = TRUE, {
- DataFrameSource$new(test_df, "123_invalid")
- DataFrameSource$new(test_df, "table-name")
- DataFrameSource$new(test_df, "table name")
- DataFrameSource$new(test_df, "")
- DataFrameSource$new(test_df, NULL)
- })
- })
-})
-
-describe("DataFrameSource engine parameter", {
- describe("with duckdb engine", {
- skip_if_not_installed("duckdb")
-
- it("creates connection with duckdb backend", {
- df_source <- local_data_frame_source(new_test_df(), engine = "duckdb")
-
- expect_s3_class(df_source, "DataFrameSource")
- expect_s3_class(df_source, "DBISource")
- expect_equal(df_source$table_name, "test_table")
- expect_equal(df_source$get_db_type(), "DuckDB")
- })
-
- it("executes queries correctly", {
- test_df <- new_test_df()
- df_source <- local_data_frame_source(test_df, engine = "duckdb")
-
- # Test filtering
- result <- df_source$execute_query(
- "SELECT * FROM test_table WHERE value > 25"
- )
- expect_equal(nrow(result), 3)
- expect_equal(result$value, c(30, 40, 50))
-
- # Test get_data
- all_data <- df_source$get_data()
- expect_equal(all_data, test_df)
-
- # Test test_query
- one_row <- df_source$test_query("SELECT * FROM test_table")
- expect_equal(nrow(one_row), 1)
- })
-
- it("returns correct schema", {
- df_source <- local_data_frame_source(
- new_mixed_types_df(),
- engine = "duckdb"
- )
- schema <- df_source$get_schema()
-
- expect_type(schema, "character")
- expect_match(schema, "Table: test_table")
- expect_match(schema, "id \\(INTEGER\\)")
- expect_match(schema, "name \\(TEXT\\)")
- expect_match(schema, "active \\(BOOLEAN\\)")
- })
- })
-
- describe("with sqlite engine", {
- skip_if_not_installed("RSQLite")
-
- it("creates connection with sqlite backend", {
- df_source <- local_data_frame_source(new_test_df(), engine = "sqlite")
-
- expect_s3_class(df_source, "DataFrameSource")
- expect_s3_class(df_source, "DBISource")
- expect_equal(df_source$table_name, "test_table")
- expect_equal(df_source$get_db_type(), "SQLite")
- })
-
- it("executes queries correctly", {
- test_df <- new_test_df()
- df_source <- local_data_frame_source(test_df, engine = "sqlite")
-
- # Test filtering
- result <- df_source$execute_query(
- "SELECT * FROM test_table WHERE value > 25"
- )
- expect_equal(nrow(result), 3)
- expect_equal(result$value, c(30, 40, 50))
-
- # Test get_data
- all_data <- df_source$get_data()
- expect_equal(all_data, test_df)
-
- # Test test_query
- one_row <- df_source$test_query("SELECT * FROM test_table")
- expect_equal(nrow(one_row), 1)
- })
-
- it("returns correct schema", {
- df_source <- local_data_frame_source(
- new_mixed_types_df(),
- engine = "sqlite"
- )
- schema <- df_source$get_schema()
-
- expect_type(schema, "character")
- expect_match(schema, "Table:")
- expect_match(schema, "test_table")
- expect_match(schema, "id \\(INTEGER\\)")
- expect_match(schema, "name \\(TEXT\\)")
- # SQLite stores booleans as INTEGER (0/1)
- expect_match(schema, "active \\(INTEGER\\)")
- })
- })
-
- describe("engine parameter validation", {
- it("is case-insensitive", {
- skip_if_not_installed("duckdb")
- skip_if_not_installed("RSQLite")
-
- # Test various case combinations
- df1 <- local_data_frame_source(new_test_df(), engine = "DUCKDB")
- expect_equal(df1$get_db_type(), "DuckDB")
-
- df2 <- local_data_frame_source(new_test_df(), engine = "DuckDb")
- expect_equal(df2$get_db_type(), "DuckDB")
-
- df3 <- local_data_frame_source(new_test_df(), engine = "SQLite")
- expect_equal(df3$get_db_type(), "SQLite")
-
- df4 <- local_data_frame_source(new_test_df(), engine = "SQLITE")
- expect_equal(df4$get_db_type(), "SQLite")
- })
-
- it("errors on invalid engine name", {
- expect_snapshot(error = TRUE, {
- DataFrameSource$new(new_test_df(), "test_table", engine = "postgres")
- })
-
- expect_snapshot(error = TRUE, {
- DataFrameSource$new(new_test_df(), "test_table", engine = "invalid")
- })
-
- expect_snapshot(error = TRUE, {
- DataFrameSource$new(new_test_df(), "test_table", engine = "")
- })
- })
-
- it("respects getOption('querychat.DataFrameSource.engine')", {
- skip_if_not_installed("duckdb")
- skip_if_not_installed("RSQLite")
-
- # Test default (duckdb)
- withr::local_options(querychat.DataFrameSource.engine = NULL)
- df1 <- DataFrameSource$new(new_test_df(), "test_table")
- withr::defer(df1$cleanup())
- expect_equal(df1$get_db_type(), "DuckDB")
-
- # Test option set to sqlite
- withr::local_options(querychat.DataFrameSource.engine = "sqlite")
- df2 <- DataFrameSource$new(new_test_df(), "test_table")
- withr::defer(df2$cleanup())
- expect_equal(df2$get_db_type(), "SQLite")
-
- # Test explicit parameter overrides option
- withr::local_options(querychat.DataFrameSource.engine = "sqlite")
- df3 <- local_data_frame_source(new_test_df(), engine = "duckdb")
- expect_equal(df3$get_db_type(), "DuckDB")
- })
- })
-})
-
-describe("DBISource$new()", {
- it("creates proper R6 object for DBISource", {
- db <- local_sqlite_connection(new_users_df(), "users")
-
- db_source <- DBISource$new(db$conn, "users")
- expect_s3_class(db_source, "DBISource")
- expect_s3_class(db_source, "DataSource")
- expect_equal(db_source$table_name, "users")
- })
-
- it("errors with non-DBI connection", {
- expect_snapshot(error = TRUE, {
- DBISource$new(list(fake = "connection"), "test_table")
- })
-
- expect_snapshot(error = TRUE, {
- DBISource$new(NULL, "test_table")
- })
-
- expect_snapshot(error = TRUE, {
- DBISource$new("not a connection", "test_table")
- })
- })
-
- it("errors with invalid table_name types", {
- db <- local_sqlite_connection(new_test_df())
-
- expect_snapshot(error = TRUE, {
- DBISource$new(db$conn, 123)
- })
-
- expect_snapshot(error = TRUE, {
- DBISource$new(db$conn, c("table1", "table2"))
- })
-
- expect_snapshot(error = TRUE, {
- DBISource$new(db$conn, list(name = "table"))
- })
- })
-
- it("errors when table does not exist", {
- db <- local_sqlite_connection(new_test_df(), "existing_table")
-
- expect_snapshot(error = TRUE, {
- DBISource$new(db$conn, "non_existent_table")
- })
- })
-})
describe("DataSource$get_schema()", {
it("returns proper schema for DataFrameSource", {
@@ -549,91 +311,6 @@ describe("DataSource$execute_query()", {
})
})
-describe("DBISource$test_query()", {
- test_df <- new_users_df()
- db <- local_sqlite_connection(test_df, "test_table")
- dbi_source <- DBISource$new(db$conn, "test_table")
-
- it("correctly retrieves one row of data", {
- result <- dbi_source$test_query("SELECT * FROM test_table")
- expect_s3_class(result, "data.frame")
- expect_equal(nrow(result), 1)
- expect_equal(result$id, 1)
-
- result <- dbi_source$test_query("SELECT * FROM test_table WHERE age > 29")
- expect_s3_class(result, "data.frame")
- expect_equal(nrow(result), 1)
- expect_equal(result$age, 30)
-
- result <- dbi_source$test_query(
- "SELECT * FROM test_table ORDER BY age DESC"
- )
- expect_s3_class(result, "data.frame")
- expect_equal(nrow(result), 1)
- expect_equal(result$age, 35)
-
- result <- dbi_source$test_query(
- "SELECT * FROM test_table WHERE age > 100"
- )
- expect_s3_class(result, "data.frame")
- expect_equal(nrow(result), 0)
- })
-
- it("handles errors correctly", {
- expect_error(dbi_source$test_query("SELECT * WRONG SYNTAX"))
-
- expect_error(dbi_source$test_query("SELECT * FROM non_existent_table"))
-
- expect_error(dbi_source$test_query(
- "SELECT non_existent_column FROM test_table"
- ))
- })
-
- it("works with different data types", {
- db <- local_sqlite_connection(new_types_df(), "types_table")
- dbi_source <- DBISource$new(db$conn, "types_table")
-
- result <- dbi_source$test_query("SELECT * FROM types_table")
- expect_s3_class(result, "data.frame")
- expect_equal(nrow(result), 1)
- expect_type(result$text_col, "character")
- expect_type(result$num_col, "double")
- expect_type(result$int_col, "integer")
- expect_type(result$bool_col, "integer")
- })
-})
-
-describe("DataFrameSource$test_query()", {
- skip_if_no_dataframe_engine()
-
- test_df <- new_users_df()
- df_source <- local_data_frame_source(test_df, "test_table")
-
- it("correctly retrieves one row of data", {
- result <- df_source$test_query("SELECT * FROM test_table")
- expect_s3_class(result, "data.frame")
- expect_equal(nrow(result), 1)
- expect_equal(result$id, 1)
-
- result <- df_source$test_query("SELECT * FROM test_table WHERE age > 29")
- expect_s3_class(result, "data.frame")
- expect_equal(nrow(result), 1)
- expect_equal(result$age, 30)
-
- result <- df_source$test_query(
- "SELECT * FROM test_table ORDER BY age DESC"
- )
- expect_s3_class(result, "data.frame")
- expect_equal(nrow(result), 1)
- expect_equal(result$age, 35)
-
- result <- df_source$test_query(
- "SELECT * FROM test_table WHERE age > 100"
- )
- expect_s3_class(result, "data.frame")
- expect_equal(nrow(result), 0)
- })
-})
describe("test_query() column validation", {
skip_if_no_dataframe_engine()
diff --git a/pkg-r/tests/testthat/test-QueryChatSystemPrompt.R b/pkg-r/tests/testthat/test-QueryChatSystemPrompt.R
index 9eb964b1..4339f077 100644
--- a/pkg-r/tests/testthat/test-QueryChatSystemPrompt.R
+++ b/pkg-r/tests/testthat/test-QueryChatSystemPrompt.R
@@ -134,6 +134,7 @@ describe("QueryChatSystemPrompt$render()", {
template <- paste(
"{{#has_tool_update}}update enabled{{/has_tool_update}}",
"{{#has_tool_query}}query enabled{{/has_tool_query}}",
+ "{{#include_query_guidelines}}query guidelines{{/include_query_guidelines}}",
sep = "\n"
)
@@ -146,6 +147,7 @@ describe("QueryChatSystemPrompt$render()", {
expect_true(grepl("update enabled", result))
expect_true(grepl("query enabled", result))
+ expect_true(grepl("query guidelines", result))
})
it("renders with query only", {
@@ -156,6 +158,7 @@ describe("QueryChatSystemPrompt$render()", {
template <- paste(
"{{#has_tool_update}}update enabled{{/has_tool_update}}",
"{{#has_tool_query}}query enabled{{/has_tool_query}}",
+ "{{#include_query_guidelines}}query guidelines{{/include_query_guidelines}}",
sep = "\n"
)
@@ -168,6 +171,7 @@ describe("QueryChatSystemPrompt$render()", {
expect_false(grepl("update enabled", result))
expect_true(grepl("query enabled", result))
+ expect_true(grepl("query guidelines", result))
})
it("renders with update only", {
@@ -178,6 +182,7 @@ describe("QueryChatSystemPrompt$render()", {
template <- paste(
"{{#has_tool_update}}update enabled{{/has_tool_update}}",
"{{#has_tool_query}}query enabled{{/has_tool_query}}",
+ "{{#include_query_guidelines}}query guidelines{{/include_query_guidelines}}",
sep = "\n"
)
@@ -190,6 +195,7 @@ describe("QueryChatSystemPrompt$render()", {
expect_true(grepl("update enabled", result))
expect_false(grepl("query enabled", result))
+ expect_true(grepl("query guidelines", result))
})
it("renders with NULL tools", {
@@ -200,6 +206,7 @@ describe("QueryChatSystemPrompt$render()", {
template <- paste(
"{{#has_tool_update}}update enabled{{/has_tool_update}}",
"{{#has_tool_query}}query enabled{{/has_tool_query}}",
+ "{{#include_query_guidelines}}query guidelines{{/include_query_guidelines}}",
"Always shown",
sep = "\n"
)
@@ -213,6 +220,7 @@ describe("QueryChatSystemPrompt$render()", {
expect_false(grepl("update enabled", result))
expect_false(grepl("query enabled", result))
+ expect_false(grepl("query guidelines", result))
expect_true(grepl("Always shown", result))
})
diff --git a/pkg-r/tests/testthat/test-TblSqlSource.R b/pkg-r/tests/testthat/test-TblSqlSource.R
new file mode 100644
index 00000000..d1abcedf
--- /dev/null
+++ b/pkg-r/tests/testthat/test-TblSqlSource.R
@@ -0,0 +1,388 @@
+describe("TblSqlSource$new()", {
+ it("creates proper R6 object for TblSqlSource", {
+ source <- local_tbl_sql_source()
+
+ expect_s3_class(source, "TblSqlSource")
+ expect_s3_class(source, "DBISource")
+ expect_s3_class(source, "DataSource")
+ expect_equal(source$table_name, "test_table")
+ expect_equal(source$get_db_type(), "DuckDB")
+ })
+
+ it("errors with non-tbl_sql input", {
+ skip_if_not_installed("duckdb")
+ skip_if_not_installed("dbplyr")
+
+ expect_snapshot(error = TRUE, {
+ TblSqlSource$new(data.frame(a = 1))
+ })
+
+ expect_snapshot(error = TRUE, {
+ TblSqlSource$new(list(a = 1, b = 2))
+ })
+ })
+
+ it("returns lazy tibble from execute_query()", {
+ source <- local_tbl_sql_source()
+
+ result <- source$execute_query("SELECT * FROM test_table WHERE value > 25")
+ expect_s3_class(result, "tbl_sql")
+ expect_s3_class(result, "tbl_lazy")
+
+ # Collect to verify data
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 3)
+ expect_equal(collected$value, c(30, 40, 50))
+ })
+
+ it("returns data frame from test_query()", {
+ source <- local_tbl_sql_source()
+
+ result <- source$test_query("SELECT * FROM test_table")
+ expect_s3_class(result, "data.frame")
+ expect_equal(nrow(result), 1)
+ })
+
+ it("returns lazy tibble from get_data()", {
+ source <- local_tbl_sql_source()
+
+ result <- source$get_data()
+ expect_s3_class(result, "tbl_sql")
+ expect_s3_class(result, "tbl_lazy")
+ })
+})
+
+describe("TblSqlSource with transformed tbl (CTE mode)", {
+ it("works with filtered tbl", {
+ source <- local_tbl_sql_source(
+ tbl_transform = function(tbl) dplyr::filter(tbl, value > 20)
+ )
+
+ # CTE should be used since tbl is transformed
+ result <- source$execute_query("SELECT * FROM test_table")
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 3)
+ expect_true(all(collected$value > 20))
+ })
+
+ it("works with selected columns tbl", {
+ source <- local_tbl_sql_source(
+ tbl_transform = function(tbl) dplyr::select(tbl, id, name)
+ )
+
+ result <- source$execute_query("SELECT * FROM test_table")
+ collected <- dplyr::collect(result)
+ expect_equal(names(collected), c("id", "name"))
+ })
+})
+
+describe("TblSqlSource edge cases - Category A: Structural Violations", {
+ # Note: TblSqlSource uses dplyr::tbl(conn, dplyr::sql(query)) which wraps
+ # the user's query as a subquery. This means some SQL constructs that work
+ # in standalone queries will fail when wrapped.
+
+ it("errors on trailing semicolon (subquery wrapping)", {
+ source <- local_tbl_sql_source()
+
+ # Semicolons inside subqueries cause syntax errors in DuckDB
+ # The query gets wrapped as: SELECT * FROM (SELECT * FROM test_table;) q01
+ expect_error(
+ dplyr::collect(source$execute_query("SELECT * FROM test_table;")),
+ regexp = "syntax error"
+ )
+ })
+
+ it("errors on trailing semicolon in CTE mode", {
+ source <- local_tbl_sql_source(
+ tbl_transform = function(tbl) dplyr::filter(tbl, value > 10)
+ )
+
+ # Same issue with CTE wrapping
+ expect_error(
+ dplyr::collect(source$execute_query("SELECT * FROM test_table;")),
+ regexp = "syntax error"
+ )
+ })
+
+ it("errors on multiple trailing semicolons", {
+ source <- local_tbl_sql_source()
+
+ expect_error(
+ dplyr::collect(source$execute_query("SELECT * FROM test_table;;;")),
+ regexp = "syntax error"
+ )
+ })
+
+ it("errors on multiple statements", {
+ source <- local_tbl_sql_source()
+
+ # Multiple statements cause syntax errors when wrapped as subquery
+ expect_error(
+ dplyr::collect(source$execute_query("SELECT 1 AS a; SELECT 2 AS b")),
+ regexp = "syntax error"
+ )
+ })
+
+ it("errors on empty SELECT (syntax error)", {
+ source <- local_tbl_sql_source()
+
+ expect_error(
+ dplyr::collect(source$execute_query("SELECT")),
+ regexp = NULL
+ )
+ })
+
+ it("errors on dangling FROM clause", {
+ source <- local_tbl_sql_source()
+
+ # SELECT without any FROM clause is valid for literals, but a dangling
+ # FROM with no table reference is a syntax error
+ expect_error(
+ dplyr::collect(source$execute_query("SELECT id FROM")),
+ regexp = NULL
+ )
+ })
+
+ it("succeeds with query without trailing semicolon", {
+ source <- local_tbl_sql_source()
+
+ # Properly formed query without semicolon works
+ result <- source$execute_query("SELECT * FROM test_table")
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 5)
+ })
+})
+
+describe("TblSqlSource edge cases - Category B: Column Naming Issues", {
+ it("handles unnamed expressions (auto-generated names)", {
+ source <- local_tbl_sql_source()
+
+ # DuckDB auto-generates names for unnamed expressions
+ result <- source$execute_query(
+ "SELECT id, 1+1, UPPER(name) FROM test_table"
+ )
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 5)
+ # Should have 3 columns (id + two computed)
+ expect_equal(ncol(collected), 3)
+ })
+
+ it("handles unnamed expressions in CTE mode", {
+ source <- local_tbl_sql_source(
+ tbl_transform = function(tbl) dplyr::filter(tbl, value > 10)
+ )
+
+ result <- source$execute_query(
+ "SELECT id, value * 2, UPPER(name) FROM test_table"
+ )
+ collected <- dplyr::collect(result)
+ expect_equal(ncol(collected), 3)
+ })
+
+ it("errors on duplicate column names from JOIN (tibble requirement)", {
+ skip_if_not_installed("duckdb")
+ skip_if_not_installed("dbplyr")
+ skip_if_not_installed("dplyr")
+
+ conn <- local_duckdb_multi_table()
+ tbl_a <- dplyr::tbl(conn, "table_a")
+ source <- TblSqlSource$new(tbl_a, "table_a")
+
+ # SELECT with explicit duplicate column names from JOIN
+ # DuckDB allows duplicate names but tibble rejects them on collect
+ result <- source$execute_query(
+ "SELECT table_a.id, table_b.id FROM table_a JOIN table_b ON table_a.id = table_b.id"
+ )
+ expect_error(
+ dplyr::collect(result),
+ regexp = "must not be duplicated|must be unique"
+ )
+ })
+
+ it("handles duplicate columns with aliases", {
+ skip_if_not_installed("duckdb")
+ skip_if_not_installed("dbplyr")
+ skip_if_not_installed("dplyr")
+
+ conn <- local_duckdb_multi_table()
+ tbl_a <- dplyr::tbl(conn, "table_a")
+ source <- TblSqlSource$new(tbl_a, "table_a")
+
+ # Using aliases to avoid duplicate column names
+ result <- source$execute_query(
+ "SELECT table_a.id AS id_a, table_b.id AS id_b FROM table_a JOIN table_b ON table_a.id = table_b.id"
+ )
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 3)
+ expect_equal(ncol(collected), 2)
+ expect_true("id_a" %in% names(collected))
+ expect_true("id_b" %in% names(collected))
+ })
+
+ it("handles reserved word as alias", {
+ source <- local_tbl_sql_source()
+
+ # Reserved words 'select' and 'from' work as column aliases when quoted
+ result <- source$execute_query(
+ "SELECT id AS \"select\", name AS \"from\" FROM test_table"
+ )
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 5)
+ expect_true("select" %in% names(collected))
+ expect_true("from" %in% names(collected))
+ })
+
+ it("handles non-reserved word as unquoted alias", {
+ source <- local_tbl_sql_source()
+
+ # A non-reserved identifier is accepted as an unquoted alias; truly
+ # reserved words would require quoting on stricter databases
+ result <- source$execute_query(
+ "SELECT id AS value_alias FROM test_table"
+ )
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 5)
+ })
+
+ it("handles empty string alias (DB-dependent)", {
+ source <- local_tbl_sql_source()
+
+ # Empty-string alias: behavior is DB-dependent; DuckDB errors on it,
+ # which expect_error() asserts below
+ expect_error(
+ {
+ result <- source$execute_query(
+ "SELECT id AS \"\" FROM test_table"
+ )
+ dplyr::collect(result)
+ },
+ regexp = NULL
+ )
+ })
+
+ it("errors on wildcard with JOIN (duplicate columns)", {
+ skip_if_not_installed("duckdb")
+ skip_if_not_installed("dbplyr")
+ skip_if_not_installed("dplyr")
+
+ conn <- local_duckdb_multi_table()
+ tbl_a <- dplyr::tbl(conn, "table_a")
+ source <- TblSqlSource$new(tbl_a, "table_a")
+
+ # SELECT * from JOIN produces duplicate 'id' columns
+ # tibble rejects duplicate names on collect
+ result <- source$execute_query(
+ "SELECT * FROM table_a JOIN table_b ON table_a.id = table_b.id"
+ )
+ expect_error(
+ dplyr::collect(result),
+ regexp = "must not be duplicated|must be unique"
+ )
+ })
+
+ it("handles wildcard with JOIN using USING clause (no duplicates)", {
+ skip_if_not_installed("duckdb")
+ skip_if_not_installed("dbplyr")
+ skip_if_not_installed("dplyr")
+
+ conn <- local_duckdb_multi_table()
+ tbl_a <- dplyr::tbl(conn, "table_a")
+ source <- TblSqlSource$new(tbl_a, "table_a")
+
+ # USING clause produces single 'id' column
+ result <- source$execute_query(
+ "SELECT * FROM table_a JOIN table_b USING (id)"
+ )
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 3)
+ # id appears only once with USING
+ expect_equal(sum(names(collected) == "id"), 1)
+ })
+})
+
+describe("TblSqlSource edge cases - Category C: ORDER BY Behavior", {
+ it("handles ORDER BY without LIMIT", {
+ source <- local_tbl_sql_source()
+
+ # ORDER BY without LIMIT is valid SQL
+ result <- source$execute_query(
+ "SELECT * FROM test_table ORDER BY value DESC"
+ )
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 5)
+ # Verify order is applied
+ expect_equal(collected$value[1], 50)
+ expect_equal(collected$value[5], 10)
+ })
+
+ it("handles ORDER BY without LIMIT in CTE mode", {
+ source <- local_tbl_sql_source(
+ tbl_transform = function(tbl) dplyr::filter(tbl, value >= 10)
+ )
+
+ result <- source$execute_query(
+ "SELECT * FROM test_table ORDER BY value DESC"
+ )
+ collected <- dplyr::collect(result)
+ expect_true(nrow(collected) > 0)
+ # Verify order is maintained through CTE
+ expect_true(collected$value[1] >= collected$value[nrow(collected)])
+ })
+
+ it("handles LIMIT without ORDER BY (non-deterministic but valid)", {
+ source <- local_tbl_sql_source()
+
+ # LIMIT without ORDER BY is valid but non-deterministic
+ result <- source$execute_query("SELECT * FROM test_table LIMIT 3")
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 3)
+ })
+
+ it("handles ORDER BY with LIMIT", {
+ source <- local_tbl_sql_source()
+
+ result <- source$execute_query(
+ "SELECT * FROM test_table ORDER BY value DESC LIMIT 2"
+ )
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 2)
+ expect_equal(collected$value, c(50, 40))
+ })
+
+ it("handles ORDER BY with LIMIT and OFFSET", {
+ source <- local_tbl_sql_source()
+
+ result <- source$execute_query(
+ "SELECT * FROM test_table ORDER BY value DESC LIMIT 2 OFFSET 1"
+ )
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 2)
+ expect_equal(collected$value, c(40, 30))
+ })
+
+ it("handles ORDER BY with column alias", {
+ source <- local_tbl_sql_source()
+
+ result <- source$execute_query(
+ "SELECT id, value AS val FROM test_table ORDER BY val DESC"
+ )
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 5)
+ expect_equal(collected$val[1], 50)
+ })
+
+ it("handles ORDER BY with expression", {
+ source <- local_tbl_sql_source()
+
+ result <- source$execute_query(
+ "SELECT id, value FROM test_table ORDER BY value * -1"
+ )
+ collected <- dplyr::collect(result)
+ expect_equal(nrow(collected), 5)
+ # ORDER BY value * -1 ascending means:
+ # -50 < -40 < -30 < -20 < -10
+ # So original values sorted: 50, 40, 30, 20, 10 (descending)
+ expect_equal(collected$value[1], 50)
+ expect_equal(collected$value[5], 10)
+ })
+})