From 2d06c86595218fe62eed8bb3bc3c5a30f6d8a227 Mon Sep 17 00:00:00 2001
From: mmb78 <62362216+mmb78@users.noreply.github.com>
Date: Tue, 24 Jun 2025 22:35:25 +0200
Subject: [PATCH 1/3] Update bindcraft.py to allow safe parallel execution

# MODIFICATION FOR PARALLEL EXECUTION
# Minimal changes from v1.5.1 to allow safe execution as a SLURM job array.
#
# 1. File locking is used to prevent data corruption when writing to shared CSV files.
# 2. A central counter file (_progress_counter.txt) tracks the total accepted designs,
#    allowing all processes to stop once the target is reached.
#
# Example SLURM Usage:
# #SBATCH --array=0-31
#
# requires FileLock library - run "pip install filelock"
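Note for readers new to filelock: every shared file gets a sidecar "<name>.lock" file, and any read-modify-write of the shared file happens while that lock is held, so concurrent workers are serialized. A minimal self-contained sketch of the same pattern (the file names here are illustrative, not the ones used by bindcraft.py):

import csv
from filelock import FileLock

demo_csv = "shared_stats.csv"              # stand-in for e.g. trajectory_csv
demo_lock = FileLock(demo_csv + ".lock")   # sidecar lock file, as in the patch

def append_row_safely(row):
    # Only one process at a time can hold the lock; all others block here
    # until it is released, so partially written rows cannot interleave.
    with demo_lock:
        with open(demo_csv, "a", newline="") as f:
            csv.writer(f).writerow(row)

append_row_safely(["design_1", 0.91])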
---
 bindcraft.py | 113 +++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 83 insertions(+), 30 deletions(-)

diff --git a/bindcraft.py b/bindcraft.py
index 55758f2..f0a9e09 100644
--- a/bindcraft.py
+++ b/bindcraft.py
@@ -1,8 +1,22 @@
 ####################################
 ###################### BindCraft Run
 ####################################
+
+# MODIFICATION FOR PARALLEL EXECUTION
+# Minimal changes from v1.5.1 to allow safe execution as a SLURM job array.
+#
+# 1. File locking is used to prevent data corruption when writing to shared CSV files.
+# 2. A central counter file (_progress_counter.txt) tracks the total accepted designs,
+#    allowing all processes to stop once the target is reached.
+#
+# Example SLURM Usage:
+# #SBATCH --array=0-31
+#
+# requires FileLock library - run "pip install filelock"
+
 ### Import dependencies
 from functions import *
+from filelock import FileLock
 
 # Check if JAX-capable GPU is available, otherwise exit
 check_jax_gpu()
@@ -48,10 +62,28 @@
 final_csv = os.path.join(target_settings["design_path"], 'final_design_stats.csv')
 failure_csv = os.path.join(target_settings["design_path"], 'failure_csv.csv')
 
-create_dataframe(trajectory_csv, trajectory_labels)
-create_dataframe(mpnn_csv, design_labels)
-create_dataframe(final_csv, final_labels)
-generate_filter_pass_csv(failure_csv, args.filters)
+# Define paths for the progress counter and all shared CSV files and their locks
+progress_counter_file = os.path.join(target_settings["design_path"], '_progress_counter.txt')
+progress_lock = FileLock(progress_counter_file + ".lock")
+
+trajectory_csv_lock = FileLock(trajectory_csv + ".lock")
+mpnn_csv_lock = FileLock(mpnn_csv + ".lock")
+final_csv_lock = FileLock(final_csv + ".lock")
+failure_csv_lock = FileLock(failure_csv + ".lock")
+finalization_lock = FileLock(os.path.join(target_settings["design_path"], "_finalization.lock"))
+
+# Initialize counter file if it doesn't exist. This is safe for multiple processes.
+with progress_lock:
+    if not os.path.exists(progress_counter_file):
+        with open(progress_counter_file, 'w') as f:
+            f.write('0')
+
+# Initialize dataframes safely by passing the corresponding lock to each function
+create_dataframe(trajectory_csv, trajectory_labels, trajectory_csv_lock)
+create_dataframe(mpnn_csv, design_labels, mpnn_csv_lock)
+create_dataframe(final_csv, final_labels, final_csv_lock)
+generate_filter_pass_csv(failure_csv, args.filters, failure_csv_lock)
+
 ####################################
 ####################################
@@ -70,8 +102,17 @@
 ### start design loop
 while True:
+    # Check global progress counter before starting a new trajectory
+    with progress_lock:
+        with open(progress_counter_file, 'r') as f:
+            accepted_count = int(f.read())
+
+    if accepted_count >= target_settings["number_of_final_designs"]:
+        print(f"Target of {target_settings['number_of_final_designs']} designs reached. Worker process is shutting down.")
+        break
+
     ### check if we have the target number of binders
-    final_designs_reached = check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels)
+    final_designs_reached = check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels, finalization_lock, mpnn_csv_lock, final_csv_lock)
 
     if final_designs_reached:
         # stop design loop execution
         break
@@ -108,7 +149,7 @@
     ### Begin binder hallucination
     trajectory = binder_hallucination(design_name, target_settings["starting_pdb"], target_settings["chains"],
                                       target_settings["target_hotspot_residues"], length, seed, helicity_value,
-                                      design_models, advanced_settings, design_paths, failure_csv)
+                                      design_models, advanced_settings, design_paths, failure_csv, failure_csv_lock)
     trajectory_metrics = copy_dict(trajectory._tmp["best"]["aux"]["log"]) # contains plddt, ptm, i_ptm, pae, i_pae
     trajectory_pdb = os.path.join(design_paths["Trajectory"], design_name + ".pdb")
 
@@ -159,7 +200,8 @@
                        trajectory_interface_scores['interface_hbond_percentage'], trajectory_interface_scores['interface_delta_unsat_hbonds'], trajectory_interface_scores['interface_delta_unsat_hbonds_percentage'],
                        trajectory_alpha_interface, trajectory_beta_interface, trajectory_loops_interface, trajectory_alpha, trajectory_beta, trajectory_loops, trajectory_interface_AA, trajectory_target_rmsd,
                        trajectory_time_text, traj_seq_notes, settings_file, filters_file, advanced_file]
-    insert_data(trajectory_csv, trajectory_data)
+    with trajectory_csv_lock:
+        insert_data(trajectory_csv, trajectory_data)
 
     if advanced_settings["enable_mpnn"]:
         # initialise MPNN counters
@@ -170,7 +212,8 @@
 
         ### MPNN redesign of starting binder
         mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)
-        existing_mpnn_sequences = set(pd.read_csv(mpnn_csv, usecols=['Sequence'])['Sequence'].values)
+        with mpnn_csv_lock:
+            existing_mpnn_sequences = set(pd.read_csv(mpnn_csv, usecols=['Sequence'])['Sequence'].values)
 
         # create set of MPNN sequences with allowed amino acid composition
         restricted_AAs = set(aa.strip().upper() for aa in advanced_settings["omit_AAs"].split(',')) if advanced_settings["force_reject_AA"] else set()
@@ -232,7 +275,7 @@
                                                                                 mpnn_sequence['seq'], mpnn_design_name,
                                                                                 target_settings["starting_pdb"], target_settings["chains"],
                                                                                 length, trajectory_pdb, prediction_models, advanced_settings,
-                                                                                filters, design_paths, failure_csv)
+                                                                                filters, design_paths, failure_csv, failure_csv_lock)
 
                 # if AF2 filters are not passed then skip the scoring
                 if not pass_af2_filters:
@@ -330,7 +373,6 @@
             mpnn_end_time = time.time() - mpnn_time
             elapsed_mpnn_text = f"{'%d hours, %d minutes, %d seconds' % (int(mpnn_end_time // 3600), int((mpnn_end_time % 3600) // 60), int(mpnn_end_time % 60))}"
 
-            # Insert statistics about MPNN design into CSV, will return None if corresponding model does note exist
             model_numbers = range(1, 6)
             statistics_labels = ['pLDDT', 'pTM', 'i_pTM', 'pAE', 'i_pAE', 'i_pLDDT', 'ss_pLDDT', 'Unrelaxed_Clashes', 'Relaxed_Clashes', 'Binder_Energy_Score', 'Surface_Hydrophobicity',
@@ -357,7 +399,8 @@
                 mpnn_data.extend([elapsed_mpnn_text, seq_notes, settings_file, filters_file, advanced_file])
 
                 # insert data into csv
-                insert_data(mpnn_csv, mpnn_data)
+                with mpnn_csv_lock:
+                    insert_data(mpnn_csv, mpnn_data)
 
                 # find best model number by pLDDT
                 plddt_values = {i: mpnn_data[i] for i in range(11, 15) if mpnn_data[i] is not None}
@@ -381,7 +424,15 @@
 
                     # insert data into final csv
                     final_data = [''] + mpnn_data
-                    insert_data(final_csv, final_data)
+                    with final_csv_lock:
+                        insert_data(final_csv, final_data)
+
+                    # Safely increment the global progress counter
+                    with progress_lock:
+                        with open(progress_counter_file, 'r') as f:
+                            current_count = int(f.read())
+                        with open(progress_counter_file, 'w') as f:
+                            f.write(str(current_count + 1))
 
                     # copy animation from accepted trajectory
                     if advanced_settings["save_design_animations"]:
@@ -400,21 +451,24 @@
 
                 else:
                     print(f"Unmet filter conditions for {mpnn_design_name}")
-                    failure_df = pd.read_csv(failure_csv)
-                    special_prefixes = ('Average_', '1_', '2_', '3_', '4_', '5_')
-                    incremented_columns = set()
-
-                    for column in filter_conditions:
-                        base_column = column
-                        for prefix in special_prefixes:
-                            if column.startswith(prefix):
-                                base_column = column.split('_', 1)[1]
-
-                        if base_column not in incremented_columns:
-                            failure_df[base_column] = failure_df[base_column] + 1
-                            incremented_columns.add(base_column)
-
-                    failure_df.to_csv(failure_csv, index=False)
+
+                    with failure_csv_lock:
+                        failure_df = pd.read_csv(failure_csv)
+                        special_prefixes = ('Average_', '1_', '2_', '3_', '4_', '5_')
+                        incremented_columns = set()
+
+                        for column in filter_conditions:
+                            base_column = column
+                            for prefix in special_prefixes:
+                                if column.startswith(prefix):
+                                    base_column = column.split('_', 1)[1]
+
+                            if base_column not in incremented_columns:
+                                failure_df[base_column] = failure_df[base_column] + 1
+                                incremented_columns.add(base_column)
+
+                        failure_df.to_csv(failure_csv, index=False)
+
                     shutil.copy(best_model_pdb, design_paths["Rejected"])
 
             # increase MPNN design number
@@ -457,6 +511,5 @@
     gc.collect()
 
 ### Script finished
-elapsed_time = time.time() - script_start_time
-elapsed_text = f"{'%d hours, %d minutes, %d seconds' % (int(elapsed_time // 3600), int((elapsed_time % 3600) // 60), int(elapsed_time % 60))}"
-print("Finished all designs. Script execution for "+str(trajectory_n)+" trajectories took: "+elapsed_text)
\ No newline at end of file
+# The final summary block is removed, as it is not meaningful in a parallel context. Each worker will exit on its own when the global target is met.
+print("Finished all designs.")
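The early-exit check that this patch adds at the top of the design loop reduces to the following self-contained sketch (the counter path and target value are illustrative stand-ins for the script's variables):

import os
from filelock import FileLock

counter_path = "_progress_counter.txt"  # shared by all worker processes
target = 100                            # stands in for target_settings["number_of_final_designs"]
counter_lock = FileLock(counter_path + ".lock")

def global_target_reached():
    # Read under the lock so a concurrent writer can never be observed mid-write.
    with counter_lock:
        if not os.path.exists(counter_path):
            return False
        with open(counter_path) as f:
            return int(f.read()) >= target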
From 3dd0f3115b64e29602914cebb5e96f39c10ebfdf Mon Sep 17 00:00:00 2001
From: mmb78 <62362216+mmb78@users.noreply.github.com>
Date: Thu, 26 Jun 2025 00:36:55 +0200
Subject: [PATCH 2/3] Update bindcraft.py to allow safe parallel execution

The process that accepts the last required binder executes the ranking and
final processing of all data. All other parallel processes stop after
finishing their current job. This version fixes a problem in the previous
version, where no process would perform the final ranking.
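The mechanism described above, as a self-contained sketch (the names and the finalize callback are illustrative; in the patch itself the trigger calls check_accepted_designs and the finalization lock is passed into it):

from filelock import FileLock

counter_path = "_progress_counter.txt"
counter_lock = FileLock(counter_path + ".lock")
finalization_lock = FileLock("_finalization.lock")  # mirrors the patch's finalization lock
target = 100  # stands in for target_settings["number_of_final_designs"]

def record_accepted_design(finalize):
    # Increment atomically: the read and the write happen under one lock
    # acquisition, so two workers can never both read the same old value.
    with counter_lock:
        with open(counter_path) as f:
            new_count = int(f.read()) + 1
        with open(counter_path, 'w') as f:
            f.write(str(new_count))
    # With ">=" more than one worker could pass this test if the counter
    # overshoots the target; the finalization lock serializes them.
    if new_count >= target:
        with finalization_lock:
            finalize()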
---
 bindcraft.py | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/bindcraft.py b/bindcraft.py
index f0a9e09..6a157ea 100644
--- a/bindcraft.py
+++ b/bindcraft.py
@@ -3,16 +3,21 @@
 ####################################
 
 # MODIFICATION FOR PARALLEL EXECUTION
-# Minimal changes from v1.5.1 to allow safe execution as a SLURM job array.
+# Minimal changes from v1.5.1 to allow safe execution as multiple concurrent processes
+# using either a job scheduler (like SLURM) or local background jobs.
 #
 # 1. File locking is used to prevent data corruption when writing to shared CSV files.
 # 2. A central counter file (_progress_counter.txt) tracks the total accepted designs,
 #    allowing all processes to stop once the target is reached.
 #
-# Example SLURM Usage:
+# Requires FileLock library:
+#   conda install -c conda-forge filelock
+#
+# Example SLURM usage:
 # #SBATCH --array=0-31
 #
-# requires FileLock library - run "pip install filelock"
+# Example local background jobs:
+# for i in {1..8}; do nohup python -u ./bindcraft.py --settings './settings_target/PDL1.json' --filters './settings_filters/default_filters.json' --advanced './settings_advanced/default_4stage_multimer.json' > output_${i}.log 2> error_${i}.log & done
 
 ### Import dependencies
 from functions import *
@@ -111,13 +116,6 @@
         print(f"Target of {target_settings['number_of_final_designs']} designs reached. Worker process is shutting down.")
         break
 
-    ### check if we have the target number of binders
-    final_designs_reached = check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels, finalization_lock, mpnn_csv_lock, final_csv_lock)
-
-    if final_designs_reached:
-        # stop design loop execution
-        break
-
     ### check if we reached maximum allowed trajectories
     max_trajectories_reached = check_n_trajectories(design_paths, advanced_settings)
 
@@ -431,8 +429,14 @@
                     with progress_lock:
                         with open(progress_counter_file, 'r') as f:
                             current_count = int(f.read())
+                        new_count = current_count + 1
                         with open(progress_counter_file, 'w') as f:
-                            f.write(str(current_count + 1))
+                            f.write(str(new_count))
+
+                    # If this process just saved the final design, it will trigger the ranking.
+                    if new_count >= target_settings["number_of_final_designs"]:
+                        print(f"FINAL DESIGN ({new_count}) FOUND! TRIGGERING FINAL RANKING...")
+                        check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels, finalization_lock, mpnn_csv_lock, final_csv_lock)
 
                     # copy animation from accepted trajectory
                     if advanced_settings["save_design_animations"]:

From 18a4e77ab7a83a4e47c7171dab7dde2c0ad8e4a5 Mon Sep 17 00:00:00 2001
From: mmb78 <62362216+mmb78@users.noreply.github.com>
Date: Thu, 26 Jun 2025 22:41:09 +0200
Subject: [PATCH 3/3] Update bindcraft.py to allow safe parallel processing

There was one more bug in the final processing step. The order of the final
processing steps was rearranged to prevent a crash during the ranking of
binders and the generation of outputs.
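The crash came from ordering: once one worker starts the final ranking, files under the shared output tree may be moved or removed while another worker is still copying its own results. The copies are therefore guarded with existence checks and performed before the ranking is triggered. The guard, as a self-contained sketch (the helper name is hypothetical):

import os
import shutil

def copy_if_exists(src, dst):
    # In a multi-worker run the source may already have been moved or cleaned
    # up by another process's finalization step; skip quietly instead of crashing.
    if os.path.exists(src) and not os.path.exists(dst):
        shutil.copy(src, dst)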
---
 bindcraft.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/bindcraft.py b/bindcraft.py
index 6a157ea..6848b08 100644
--- a/bindcraft.py
+++ b/bindcraft.py
@@ -433,16 +433,13 @@
                         with open(progress_counter_file, 'w') as f:
                             f.write(str(new_count))
 
-                    # If this process just saved the final design, it will trigger the ranking.
-                    if new_count >= target_settings["number_of_final_designs"]:
-                        print(f"FINAL DESIGN ({new_count}) FOUND! TRIGGERING FINAL RANKING...")
-                        check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels, finalization_lock, mpnn_csv_lock, final_csv_lock)
-
                     # copy animation from accepted trajectory
                     if advanced_settings["save_design_animations"]:
                         accepted_animation = os.path.join(design_paths["Accepted/Animation"], f"{design_name}.html")
                         if not os.path.exists(accepted_animation):
-                            shutil.copy(os.path.join(design_paths["Trajectory/Animation"], f"{design_name}.html"), accepted_animation)
+                            source_animation = os.path.join(design_paths["Trajectory/Animation"], f"{design_name}.html")
+                            if os.path.exists(source_animation):
+                                shutil.copy(source_animation, accepted_animation)
 
                     # copy plots of accepted trajectory
                     plot_files = os.listdir(design_paths["Trajectory/Plots"])
@@ -451,7 +448,13 @@
                             source_plot = os.path.join(design_paths["Trajectory/Plots"], accepted_plot)
                             target_plot = os.path.join(design_paths["Accepted/Plots"], accepted_plot)
                             if not os.path.exists(target_plot):
-                                shutil.copy(source_plot, target_plot)
+                                if os.path.exists(source_plot):
+                                    shutil.copy(source_plot, target_plot)
+
+                    # If this process just saved the final design, it will trigger the ranking.
+                    if new_count >= target_settings["number_of_final_designs"]:
+                        print(f"FINAL DESIGN ({new_count}) FOUND! TRIGGERING FINAL RANKING...")
+                        check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels, finalization_lock, mpnn_csv_lock, final_csv_lock)
 
                 else:
                     print(f"Unmet filter conditions for {mpnn_design_name}")
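A quick way to sanity-check the locking before a long run is the self-contained stress test below (not part of the patch; it assumes only the filelock package). It spawns several local processes that all increment a shared counter and verifies that no update is lost:

import multiprocessing
import os
from filelock import FileLock

COUNTER = "_test_counter.txt"  # throwaway file, not bindcraft's real counter

def bump(n):
    # Each process creates its own FileLock object; the lock file on disk is shared.
    lock = FileLock(COUNTER + ".lock")
    for _ in range(n):
        with lock:
            with open(COUNTER) as f:
                count = int(f.read())
            with open(COUNTER, 'w') as f:
                f.write(str(count + 1))

if __name__ == "__main__":
    with open(COUNTER, 'w') as f:
        f.write('0')
    procs = [multiprocessing.Process(target=bump, args=(100,)) for _ in range(8)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    with open(COUNTER) as f:
        total = int(f.read())
    print("expected 800, got", total)  # should always be 800 with the lock held
    os.remove(COUNTER)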