Hi,
I was running the provided checkpoint with the `model_inference` pipeline on the CATH 4.2 test set. However, I keep seeing errors and warnings for a number of proteins:
```
cath_download/all/1CJX.pdb Missing residue: C 356 fill with 0
cath_download/all/1CJX.pdb Missing residue: D 356 fill with 0
Failed 1cjx.B: The size of tensor a (1408) must match the size of tensor b (1200) at non-singleton dimension 1
cath_download/all/5FGO.pdb Missing residue: A 214 fill with 0
cath_download/all/5FGO.pdb Missing residue: D 254 fill with 0
Failed 5fgo.A: 'bool' object has no attribute 'atom_pos'
```
I assume the "fill with 0" warning is harmless and only occurs when certain coordinates are NaN in the PDB file; I would like to confirm this.
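For reference, this is the behavior I am assuming on the loader side (a hypothetical sketch, not the repo's actual code; `fill_missing_coords` is a made-up name):

```python
import numpy as np

def fill_missing_coords(coords: np.ndarray) -> np.ndarray:
    # coords: (num_residues, num_atoms, 3); NaN entries mark atoms that are
    # absent from the deposited structure. Replacing them with 0 keeps the
    # downstream tensors dense, which would match the "fill with 0" warning.
    return np.nan_to_num(coords, nan=0.0)
```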
Also, to obtain the CATH 4.2 PDBs I use an approach similar to the `cath_downloader` in this repo, downloading from `RCSB_URL = "https://files.rcsb.org/view/{pdb}.pdb"`.
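Concretely, my download step looks roughly like this (a minimal sketch of my own helper, not the repo's `cath_downloader`):

```python
import os
import requests

RCSB_URL = "https://files.rcsb.org/view/{pdb}.pdb"

def download_pdb(pdb_id: str, out_dir: str) -> str:
    # Fetch one structure from RCSB and cache it as <out_dir>/<PDBID>.pdb.
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"{pdb_id.upper()}.pdb")
    if not os.path.exists(path):
        resp = requests.get(RCSB_URL.format(pdb=pdb_id.upper()), timeout=30)
        resp.raise_for_status()
        with open(path, 'w') as f:
            f.write(resp.text)
    return path
```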
This is the exact script I am running:
```python
import argparse
import os
import json
import sys
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from hydra import initialize, compose

sys.path.append('/data/vvjain3/inverse_folding/MapDiff')
from model.egnn_pytorch.egnn_net import EGNN_NET
from model.ipa.ipa_net import IPANetPredictor
from model.prior_diff import Prior_Diff
from utils import enable_dropout
from dataloader.collator import CollatorDiff
from data.generate_graph_cath import pdb2graph, get_processed_graph, amino_acids_type


def ensure_dir(path: str):
    os.makedirs(path, exist_ok=True)


def load_test_names(splits_path: str) -> List[str]:
    with open(splits_path, 'r') as f:
        splits = json.load(f)
    return list(splits.get('test', []))


def pdbid_from_name(name: str) -> str:
    return name.split('.')[0].upper()


def infer_one(model, device, pdb_path: str, ensemble: int, ddim_steps: int) -> Tuple[str, str, float, float]:
    # Build the graph for a single structure and collate it into a batch of one.
    graph_raw = pdb2graph(pdb_path)
    graph = get_processed_graph(graph_raw)
    collator = CollatorDiff()
    g_batch, ipa_batch = collator([graph])
    g_batch, ipa_batch = g_batch.to(device), ipa_batch.to(device)
    model.eval()
    with torch.no_grad():
        # MC-dropout ensemble: keep dropout active at inference and average logits.
        ens_logits = []
        enable_dropout(model)
        for _ in range(ensemble):
            logits, _ = model.mc_ddim_sample(g_batch, ipa_batch, diverse=True, step=ddim_steps)
            ens_logits.append(logits)
        mean_logits = torch.stack(ens_logits).mean(dim=0).cpu()
    true_label = g_batch.x.cpu()
    true_seq = ''.join(amino_acids_type[i] for i in true_label.argmax(dim=1).tolist())
    pred_seq = ''.join(amino_acids_type[i] for i in mean_logits.argmax(dim=1).tolist())
    ll_fullseq = F.cross_entropy(mean_logits, true_label, reduction='mean').item()
    perplexity = float(np.exp(ll_fullseq))
    recovery = float((mean_logits.argmax(dim=1) == true_label.argmax(dim=1)).sum().item() / true_label.shape[0])
    return true_seq, pred_seq, perplexity, recovery


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument('--splits', required=True, help='Path to CATH chain_set_splits.json (4.2)')
    ap.add_argument('--pdb-dir', required=True, help='Directory containing downloaded PDB files (RCSB)')
    ap.add_argument('--out-dir', required=True, help='Output directory for predicted sequences and summary CSV')
    ap.add_argument('--device', default='cuda:0')
    args = ap.parse_args()

    ensure_dir(args.out_dir)
    seq_dir = os.path.join(args.out_dir, 'seqs')
    ensure_dir(seq_dir)

    with initialize(version_base=None, config_path="conf"):
        cfg = compose(config_name="inference")
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    egnn = EGNN_NET(input_feat_dim=cfg.model.input_feat_dim, hidden_channels=cfg.model.hidden_dim,
                    edge_attr_dim=cfg.model.edge_attr_dim, dropout=cfg.model.drop_out,
                    n_layers=cfg.model.depth, update_edge=cfg.model.update_edge,
                    norm_coors=cfg.model.norm_coors, update_coors=cfg.model.update_coors,
                    update_global=cfg.model.update_global, embedding=cfg.model.embedding,
                    embedding_dim=cfg.model.embedding_dim, norm_feat=cfg.model.norm_feat,
                    embed_ss=cfg.model.embed_ss)
    ipa = IPANetPredictor(dropout=cfg.model.ipa_drop_out)
    model = Prior_Diff(egnn, ipa, timesteps=cfg.diffusion.timesteps,
                       objective=cfg.diffusion.objective,
                       noise_type=cfg.diffusion.noise_type,
                       sample_method=cfg.diffusion.sample_method,
                       min_mask_ratio=cfg.mask_prior.min_mask_ratio,
                       dev_mask_ratio=cfg.mask_prior.dev_mask_ratio,
                       marginal_dist_path=cfg.dataset.marginal_train_dir).to(device)

    # Load weights
    checkpoint = torch.load(cfg.test_model.path, map_location=device)
    model.load_state_dict(checkpoint['model'], strict=True)

    # Iterate over the test list; skip entries whose PDB file was not downloaded.
    test_names = load_test_names(args.splits)
    summary_rows = []
    for name in tqdm(test_names, desc='MapDiff Inference'):
        pdbid = pdbid_from_name(name)
        pdb_path = os.path.join(args.pdb_dir, f"{pdbid}.pdb")
        if not os.path.exists(pdb_path):
            continue
        try:
            true_seq, pred_seq, ppl, rec = infer_one(model, device, pdb_path,
                                                     cfg.diffusion.ensemble_num, cfg.diffusion.ddim_steps)
        except Exception as e:
            print(f"Failed {name}: {e}")
            continue
        out_file = os.path.join(seq_dir, f"{name}.fa")
        with open(out_file, 'w') as f:
            f.write(f">{name}\n")
            f.write(pred_seq + "\n")
        summary_rows.append((name, pdbid, len(pred_seq), ppl, rec))

    csv_path = os.path.join(args.out_dir, 'summary.csv')
    with open(csv_path, 'w') as f:
        f.write('name,pdbid,length,perplexity,recovery\n')
        for r in summary_rows:
            f.write(','.join(str(x) for x in r) + '\n')
    print(f"Done. Wrote {len(summary_rows)} sequences to {seq_dir}")


if __name__ == '__main__':
    main()
```
Please let me know how I can fix these errors. They do not occur for every protein, but for roughly 20-30% of them. I can provide a complete stack trace if required.
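For example, to capture the full trace for one of the failing structures I can re-run it in isolation (reusing `infer_one`, `model`, `device`, and `cfg` from the script above):

```python
import traceback

try:
    # 1CJX is one of the structures that fails with the size-mismatch error.
    infer_one(model, device, 'cath_download/all/1CJX.pdb',
              cfg.diffusion.ensemble_num, cfg.diffusion.ddim_steps)
except Exception:
    traceback.print_exc()
```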