diff --git a/bindcraft.py b/bindcraft.py
index 55758f2..6848b08 100644
--- a/bindcraft.py
+++ b/bindcraft.py
@@ -1,8 +1,27 @@
 ####################################
 ###################### BindCraft Run
 ####################################
+
+# MODIFICATION FOR PARALLEL EXECUTION
+# Minimal changes from v1.5.1 to allow safe execution as multiple concurrent processes
+# using either a job scheduler (like SLURM) or local background jobs.
+#
+# 1. File locking is used to prevent data corruption when writing to shared CSV files.
+# 2. A central counter file (_progress_counter.txt) tracks the total accepted designs,
+#    allowing all processes to stop once the target is reached.
+#
+# Requires the FileLock library:
+#   conda install -c conda-forge filelock
+#
+# Example SLURM usage:
+#   #SBATCH --array=0-31
+#
+# Example local background jobs:
+#   for i in {1..8}; do nohup python -u ./bindcraft.py --settings './settings_target/PDL1.json' --filters './settings_filters/default_filters.json' --advanced './settings_advanced/default_4stage_multimer.json' > output_${i}.log 2> error_${i}.log & done
+
 ### Import dependencies
 from functions import *
+from filelock import FileLock
 
 # Check if JAX-capable GPU is available, otherwise exit
 check_jax_gpu()
@@ -48,10 +67,28 @@
 
 final_csv = os.path.join(target_settings["design_path"], 'final_design_stats.csv')
 failure_csv = os.path.join(target_settings["design_path"], 'failure_csv.csv')
 
-create_dataframe(trajectory_csv, trajectory_labels)
-create_dataframe(mpnn_csv, design_labels)
-create_dataframe(final_csv, final_labels)
-generate_filter_pass_csv(failure_csv, args.filters)
+# Define paths for the progress counter and all shared CSV files and their locks
+progress_counter_file = os.path.join(target_settings["design_path"], '_progress_counter.txt')
+progress_lock = FileLock(progress_counter_file + ".lock")
+
+trajectory_csv_lock = FileLock(trajectory_csv + ".lock")
+mpnn_csv_lock = FileLock(mpnn_csv + ".lock")
+final_csv_lock = FileLock(final_csv + ".lock")
+failure_csv_lock = FileLock(failure_csv + ".lock")
+finalization_lock = FileLock(os.path.join(target_settings["design_path"], "_finalization.lock"))
+
+# Initialize counter file if it doesn't exist. This is safe for multiple processes.
+with progress_lock:
+    if not os.path.exists(progress_counter_file):
+        with open(progress_counter_file, 'w') as f:
+            f.write('0')
+
+# Initialize dataframes safely by passing the corresponding lock to each function
+create_dataframe(trajectory_csv, trajectory_labels, trajectory_csv_lock)
+create_dataframe(mpnn_csv, design_labels, mpnn_csv_lock)
+create_dataframe(final_csv, final_labels, final_csv_lock)
+generate_filter_pass_csv(failure_csv, args.filters, failure_csv_lock)
+
 ####################################
 ####################################
@@ -70,11 +107,13 @@
 
 ### start design loop
 while True:
-    ### check if we have the target number of binders
-    final_designs_reached = check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels)
-
-    if final_designs_reached:
-        # stop design loop execution
+    # Check global progress counter before starting a new trajectory
+    with progress_lock:
+        with open(progress_counter_file, 'r') as f:
+            accepted_count = int(f.read())
+
+    if accepted_count >= target_settings["number_of_final_designs"]:
+        print(f"Target of {target_settings['number_of_final_designs']} designs reached. Worker process is shutting down.")
         break
 
     ### check if we reached maximum allowed trajectories
@@ -108,7 +147,7 @@
     ### Begin binder hallucination
     trajectory = binder_hallucination(design_name, target_settings["starting_pdb"], target_settings["chains"],
                                         target_settings["target_hotspot_residues"], length, seed, helicity_value,
-                                        design_models, advanced_settings, design_paths, failure_csv)
+                                        design_models, advanced_settings, design_paths, failure_csv, failure_csv_lock)
 
     trajectory_metrics = copy_dict(trajectory._tmp["best"]["aux"]["log"]) # contains plddt, ptm, i_ptm, pae, i_pae
     trajectory_pdb = os.path.join(design_paths["Trajectory"], design_name + ".pdb")
@@ -159,7 +198,8 @@
                            trajectory_interface_scores['interface_hbond_percentage'], trajectory_interface_scores['interface_delta_unsat_hbonds'],
                            trajectory_interface_scores['interface_delta_unsat_hbonds_percentage'], trajectory_alpha_interface, trajectory_beta_interface,
                            trajectory_loops_interface, trajectory_alpha, trajectory_beta, trajectory_loops, trajectory_interface_AA, trajectory_target_rmsd,
                            trajectory_time_text, traj_seq_notes, settings_file, filters_file, advanced_file]
-        insert_data(trajectory_csv, trajectory_data)
+        with trajectory_csv_lock:
+            insert_data(trajectory_csv, trajectory_data)
 
         if advanced_settings["enable_mpnn"]:
             # initialise MPNN counters
@@ -170,7 +210,8 @@
             ### MPNN redesign of starting binder
             mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)
-            existing_mpnn_sequences = set(pd.read_csv(mpnn_csv, usecols=['Sequence'])['Sequence'].values)
+            with mpnn_csv_lock:
+                existing_mpnn_sequences = set(pd.read_csv(mpnn_csv, usecols=['Sequence'])['Sequence'].values)
 
             # create set of MPNN sequences with allowed amino acid composition
             restricted_AAs = set(aa.strip().upper() for aa in advanced_settings["omit_AAs"].split(',')) if advanced_settings["force_reject_AA"] else set()
@@ -232,7 +273,7 @@
                                                                     mpnn_sequence['seq'], mpnn_design_name,
                                                                     target_settings["starting_pdb"], target_settings["chains"],
                                                                     length, trajectory_pdb, prediction_models, advanced_settings,
-                                                                    filters, design_paths, failure_csv)
+                                                                    filters, design_paths, failure_csv, failure_csv_lock)
 
                     # if AF2 filters are not passed then skip the scoring
                     if not pass_af2_filters:
@@ -330,7 +371,6 @@
                     mpnn_end_time = time.time() - mpnn_time
                     elapsed_mpnn_text = f"{'%d hours, %d minutes, %d seconds' % (int(mpnn_end_time // 3600), int((mpnn_end_time % 3600) // 60), int(mpnn_end_time % 60))}"
 
-                    # Insert statistics about MPNN design into CSV, will return None if corresponding model does note exist
                     model_numbers = range(1, 6)
                     statistics_labels = ['pLDDT', 'pTM', 'i_pTM', 'pAE', 'i_pAE', 'i_pLDDT', 'ss_pLDDT', 'Unrelaxed_Clashes', 'Relaxed_Clashes', 'Binder_Energy_Score', 'Surface_Hydrophobicity',
@@ -357,7 +397,8 @@
                     mpnn_data.extend([elapsed_mpnn_text, seq_notes, settings_file, filters_file, advanced_file])
 
                     # insert data into csv
-                    insert_data(mpnn_csv, mpnn_data)
+                    with mpnn_csv_lock:
+                        insert_data(mpnn_csv, mpnn_data)
 
                     # find best model number by pLDDT
                     plddt_values = {i: mpnn_data[i] for i in range(11, 15) if mpnn_data[i] is not None}
@@ -381,13 +422,24 @@
 
                         # insert data into final csv
                         final_data = [''] + mpnn_data
-                        insert_data(final_csv, final_data)
+                        with final_csv_lock:
+                            insert_data(final_csv, final_data)
+
+                        # Safely increment the global progress counter
+                        with progress_lock:
+                            with open(progress_counter_file, 'r') as f:
+                                current_count = int(f.read())
+                            new_count = current_count + 1
+                            with open(progress_counter_file, 'w') as f:
+                                f.write(str(new_count))
 
                         # copy animation from accepted trajectory
                         if advanced_settings["save_design_animations"]:
                             accepted_animation = os.path.join(design_paths["Accepted/Animation"], f"{design_name}.html")
                             if not os.path.exists(accepted_animation):
-                                shutil.copy(os.path.join(design_paths["Trajectory/Animation"], f"{design_name}.html"), accepted_animation)
+                                source_animation = os.path.join(design_paths["Trajectory/Animation"], f"{design_name}.html")
+                                if os.path.exists(source_animation):
+                                    shutil.copy(source_animation, accepted_animation)
 
                         # copy plots of accepted trajectory
                         plot_files = os.listdir(design_paths["Trajectory/Plots"])
@@ -396,25 +448,34 @@
                             source_plot = os.path.join(design_paths["Trajectory/Plots"], accepted_plot)
                             target_plot = os.path.join(design_paths["Accepted/Plots"], accepted_plot)
                             if not os.path.exists(target_plot):
-                                shutil.copy(source_plot, target_plot)
+                                if os.path.exists(source_plot):
+                                    shutil.copy(source_plot, target_plot)
+
+                        # If this process just saved the final design, it will trigger the ranking.
+                        if new_count >= target_settings["number_of_final_designs"]:
+                            print(f"FINAL DESIGN ({new_count}) FOUND! TRIGGERING FINAL RANKING...")
+                            check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels, finalization_lock, mpnn_csv_lock, final_csv_lock)
 
                     else:
                         print(f"Unmet filter conditions for {mpnn_design_name}")
-                        failure_df = pd.read_csv(failure_csv)
-                        special_prefixes = ('Average_', '1_', '2_', '3_', '4_', '5_')
-                        incremented_columns = set()
-
-                        for column in filter_conditions:
-                            base_column = column
-                            for prefix in special_prefixes:
-                                if column.startswith(prefix):
-                                    base_column = column.split('_', 1)[1]
-
-                            if base_column not in incremented_columns:
-                                failure_df[base_column] = failure_df[base_column] + 1
-                                incremented_columns.add(base_column)
-
-                        failure_df.to_csv(failure_csv, index=False)
+
+                        with failure_csv_lock:
+                            failure_df = pd.read_csv(failure_csv)
+                            special_prefixes = ('Average_', '1_', '2_', '3_', '4_', '5_')
+                            incremented_columns = set()
+
+                            for column in filter_conditions:
+                                base_column = column
+                                for prefix in special_prefixes:
+                                    if column.startswith(prefix):
+                                        base_column = column.split('_', 1)[1]
+
+                                if base_column not in incremented_columns:
+                                    failure_df[base_column] = failure_df[base_column] + 1
+                                    incremented_columns.add(base_column)
+
+                            failure_df.to_csv(failure_csv, index=False)
+
                         shutil.copy(best_model_pdb, design_paths["Rejected"])
 
                     # increase MPNN design number
@@ -457,6 +518,5 @@
     gc.collect()
 
 ### Script finished
-elapsed_time = time.time() - script_start_time
-elapsed_text = f"{'%d hours, %d minutes, %d seconds' % (int(elapsed_time // 3600), int((elapsed_time % 3600) // 60), int(elapsed_time % 60))}"
-print("Finished all designs. Script execution for "+str(trajectory_n)+" trajectories took: "+elapsed_text)
\ No newline at end of file
+# The final summary block is removed, as it is not meaningful in a parallel context. Each worker will exit on its own when the global target is met.
+print("Finished all designs.")
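
####################################
# Appendix (editor's sketch, not part of the patch above): lock-aware CSV helpers.
# The diff passes FileLock objects into create_dataframe() and
# generate_filter_pass_csv(), and wraps every insert_data() call at the call
# site, which implies matching changes in functions.py that this diff does not
# show. The sketch below is a hypothetical illustration of what those helpers
# could look like; the real signatures in functions.py may differ.
####################################

import csv
import os

def create_dataframe(csv_file, columns, lock):
    # Only the first process to acquire the lock writes the header row;
    # every other process sees the file already exists and does nothing.
    with lock:
        if not os.path.exists(csv_file):
            with open(csv_file, 'w', newline='') as f:
                csv.writer(f).writerow(columns)

def insert_data(csv_file, data_array):
    # Appends a single row. The helper itself stays lock-free because the
    # patched bindcraft.py always wraps the call in "with <csv>_lock:".
    with open(csv_file, 'a', newline='') as f:
        csv.writer(f).writerow(data_array)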
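
####################################
# Appendix (editor's sketch, not part of the patch above): the progress-counter
# pattern. The essential property is that the read, the increment, and the
# write all happen under a single lock acquisition, so two workers can never
# read the same value and both write value+1. A condensed, standalone version
# of the pattern used in the patch (the file name matches the patch; the
# helper functions themselves are illustrative):
####################################

import os
from filelock import FileLock

counter_file = '_progress_counter.txt'
counter_lock = FileLock(counter_file + '.lock')

def increment_counter():
    # Atomic read-modify-write: no other process can interleave between the
    # read and the write while the lock is held.
    with counter_lock:
        with open(counter_file, 'r') as f:
            count = int(f.read())
        count += 1
        with open(counter_file, 'w') as f:
            f.write(str(count))
    return count

def target_reached(target):
    # Even the read-only check takes the lock, so a worker never reads a
    # half-written file mid-update.
    with counter_lock:
        if not os.path.exists(counter_file):
            return False
        with open(counter_file, 'r') as f:
            return int(f.read()) >= target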
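
####################################
# Appendix (editor's sketch, not part of the patch above): a quick concurrency
# smoke test. Hammering the counter from several processes and checking that
# no increments are lost is an easy way to validate the locking before
# launching a long BindCraft run. Hypothetical standalone script, not part of
# BindCraft:
####################################

import os
from multiprocessing import Pool
from filelock import FileLock

COUNTER = '_test_counter.txt'

def bump(_):
    # Each call performs one locked read-modify-write, exactly like the
    # accepted-design counter update in the patch.
    with FileLock(COUNTER + '.lock'):
        with open(COUNTER, 'r') as f:
            n = int(f.read())
        with open(COUNTER, 'w') as f:
            f.write(str(n + 1))

if __name__ == '__main__':
    with open(COUNTER, 'w') as f:
        f.write('0')
    with Pool(8) as p:
        p.map(bump, range(200))
    with open(COUNTER) as f:
        assert f.read() == '200', "lost updates: locking is broken"
    os.remove(COUNTER)
    os.remove(COUNTER + '.lock')
    print("counter locking OK")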