From ebf2e1c4439ae4112c4191806261a9a30206b361 Mon Sep 17 00:00:00 2001 From: Ross Cutler <46252169+rosscutler@users.noreply.github.com> Date: Wed, 4 Jun 2025 14:33:00 -0700 Subject: [PATCH] Fix balanced block shuffling in CCR --- src/create_input.py | 60 +++++++++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/src/create_input.py b/src/create_input.py index ac9ee88..eae12b2 100644 --- a/src/create_input.py +++ b/src/create_input.py @@ -445,14 +445,21 @@ def create_input_for_dcrccr(cfg, df, output_path): output_df = pd.DataFrame() packing_strategy = cfg.get("clip_packing_strategy", "random").strip().lower() - if packing_strategy == "balanced_block": - add_clips_balanced_block_ccr(clips, refs, cfg["condition_pattern"], cfg.get("block_keys", cfg["condition_keys"]) - , n_clips_per_session, output_df) - elif packing_strategy == "random": - add_clips_random_ccr(clips, refs, n_clips_per_session, output_df) - - n_sessions = math.ceil(n_clips / n_clips_per_session) - print(f'{n_clips} clips and {n_sessions} sessions') + if packing_strategy == "balanced_block": + add_clips_balanced_block_ccr( + clips, + refs, + cfg["condition_pattern"], + cfg.get("block_keys", cfg["condition_keys"]), + n_clips_per_session, + output_df, + ) + elif packing_strategy == "random": + add_clips_random_ccr(clips, refs, n_clips_per_session, output_df) + + # number of sessions equals the number of rows in output_df + n_sessions = len(output_df) + print(f"{n_clips} clips and {n_sessions} sessions") # create math math_source = df['math'].dropna() @@ -482,24 +489,25 @@ def create_input_for_dcrccr(cfg, df, output_path): 'CMP4_A': new_4[:, 6], 'CMP4_B': new_4[:, 7]}) # add math output_df['math'] = math_output - # rating_clips - # repeat some clips to have a full design - n_questions = int(cfg['number_of_clips_per_session']) - needed_clips = n_sessions * n_questions - - full_clips = np.tile(clips.to_numpy(), (needed_clips // n_clips) + 1)[:needed_clips] - full_refs = np.tile(refs.to_numpy(), (needed_clips // n_clips) + 1)[:needed_clips] - - full = list(zip(full_clips, full_refs)) - random.shuffle(full) - full_clips, full_refs = zip(*full) - - clips_sessions = np.reshape(full_clips, (n_sessions, n_questions)) - refs_sessions = np.reshape(full_refs, (n_sessions, n_questions)) - - for q in range(n_questions): - output_df[f'Q{q}_P'] = clips_sessions[:, q] - output_df[f'Q{q}_R'] = refs_sessions[:, q] + if packing_strategy == "random": + # rating_clips + # repeat some clips to have a full design + n_questions = int(cfg["number_of_clips_per_session"]) + needed_clips = n_sessions * n_questions + + full_clips = np.tile(clips.to_numpy(), (needed_clips // n_clips) + 1)[:needed_clips] + full_refs = np.tile(refs.to_numpy(), (needed_clips // n_clips) + 1)[:needed_clips] + + full = list(zip(full_clips, full_refs)) + random.shuffle(full) + full_clips, full_refs = zip(*full) + + clips_sessions = np.reshape(full_clips, (n_sessions, n_questions)) + refs_sessions = np.reshape(full_refs, (n_sessions, n_questions)) + + for q in range(n_questions): + output_df[f"Q{q}_P"] = clips_sessions[:, q] + output_df[f"Q{q}_R"] = refs_sessions[:, q] # trappings if int(cfg['number_of_trapping_per_session']) > 0: