From 323217f3c40f9f894ac866ce38c03f77e92ab9c9 Mon Sep 17 00:00:00 2001 From: Ross Cutler <46252169+rosscutler@users.noreply.github.com> Date: Wed, 4 Jun 2025 14:33:43 -0700 Subject: [PATCH] Add initial pytest suite --- tests/conftest.py | 3 +++ tests/test_create_input.py | 33 +++++++++++++++++++++++++++++++++ tests/test_result_parser.py | 12 ++++++++++++ 3 files changed, 48 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/test_create_input.py create mode 100644 tests/test_result_parser.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..8abafde --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,3 @@ +import os +import sys +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) diff --git a/tests/test_create_input.py b/tests/test_create_input.py new file mode 100644 index 0000000..76ccb0f --- /dev/null +++ b/tests/test_create_input.py @@ -0,0 +1,33 @@ +import pandas as pd +import pytest + +from src.create_input import conv_filename_to_condition, validate_inputs + + +def test_conv_filename_to_condition_match(): + pattern = r"(?P[^_]+)_(?P\d+)\.wav" + result = conv_filename_to_condition("white_10.wav", pattern) + assert list(result.items()) == [("level", "10"), ("noise", "white")] + + +def test_conv_filename_to_condition_no_match(): + pattern = r"(?P[^_]+)_(?P\d+)\.wav" + result = conv_filename_to_condition("unexpected.wav", pattern) + assert result == {"Unknown": "NoCondition"} + + +def test_validate_inputs_acr_minimal(): + cfg = {"number_of_gold_clips_per_session": "0"} + df = pd.DataFrame(columns=[ + "rating_clips", "math", "pair_a", "pair_b", "trapping_clips", "trapping_ans" + ]) + validate_inputs(cfg, df, "acr") + + +def test_validate_inputs_missing_column(): + cfg = {"number_of_gold_clips_per_session": "0"} + df = pd.DataFrame(columns=[ + "rating_clips", "math", "pair_a", "trapping_clips", "trapping_ans" + ]) + with pytest.raises(AssertionError): + validate_inputs(cfg, df, "acr") diff --git a/tests/test_result_parser.py b/tests/test_result_parser.py new file mode 100644 index 0000000..1e41678 --- /dev/null +++ b/tests/test_result_parser.py @@ -0,0 +1,12 @@ +from src.result_parser import outliers_modified_z_score, outliers_z_score + + +def test_outliers_modified_z_score_removes_outlier(): + data = [1, 1, 1, 100] + assert outliers_modified_z_score(data) == [1, 1, 1] + + +def test_outliers_z_score_threshold_high(): + data = [10, 10, 10, 1000] + # With default threshold 3.29 the outlier is not removed + assert outliers_z_score(data) == data