
Error when running model_inference on cath4.2 test set PDBs #7

@Vedaant-J

Hi,

I was running the provided checkpoint with the model_inference pipeline code on the CATH 4.2 test set.

However, I keep seeing errors and warnings for a number of proteins:

cath_download/all/1CJX.pdb Missing residue:  C 356 fill with 0
cath_download/all/1CJX.pdb Missing residue:  D 356 fill with 0
Failed 1cjx.B: The size of tensor a (1408) must match the size of tensor b (1200) at non-singleton dimension 1

cath_download/all/5FGO.pdb Missing residue:  A 214 fill with 0
cath_download/all/5FGO.pdb Missing residue:  D 254 fill with 0
Failed 5fgo.A: 'bool' object has no attribute 'atom_pos'

I assume the 'fill with 0' warning is harmless and only occurs when certain coordinates are NaN in the PDB file; I would like to confirm this.
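For context, this is roughly how I am checking which residues have missing or NaN backbone coordinates (a minimal sketch using Biopython; the helper and the backbone atom list are my own, not from this repo):

import math
from Bio.PDB import PDBParser

BACKBONE = ("N", "CA", "C", "O")

def missing_backbone_residues(pdb_path):
    """Yield (chain_id, residue_number) for residues that lack a backbone
    atom or contain NaN coordinates."""
    structure = PDBParser(QUIET=True).get_structure("s", pdb_path)
    for model in structure:
        for chain in model:
            for res in chain:
                if res.id[0] != " ":  # skip heteroatoms and waters
                    continue
                for name in BACKBONE:
                    if name not in res or any(math.isnan(float(c)) for c in res[name].coord):
                        yield chain.id, res.id[1]
                        break

for chain_id, resnum in missing_backbone_residues("cath_download/all/1CJX.pdb"):
    print(chain_id, resnum)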

Also, to obtain the CATH 4.2 PDBs, I use the same approach as the cath_downloader in this repo and download from RCSB_URL = "https://files.rcsb.org/view/{pdb}.pdb".
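Simplified, my download step looks like this (a sketch of the loop; retry and error handling omitted):

import os
import requests

RCSB_URL = "https://files.rcsb.org/view/{pdb}.pdb"

def download_pdb(pdb_id: str, out_dir: str) -> str:
    """Fetch one entry from RCSB and save it as <out_dir>/<PDBID>.pdb."""
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{pdb_id.upper()}.pdb")
    if not os.path.exists(out_path):
        resp = requests.get(RCSB_URL.format(pdb=pdb_id.upper()), timeout=30)
        resp.raise_for_status()
        with open(out_path, "w") as f:
            f.write(resp.text)
    return out_path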

This is the exact script I am running:


import argparse
import os
import json
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from hydra import initialize_config_dir, compose
from typing import List, Tuple

import sys
sys.path.append('/data/vvjain3/inverse_folding/MapDiff')
from model.egnn_pytorch.egnn_net import EGNN_NET
from model.ipa.ipa_net import IPANetPredictor
from model.prior_diff import Prior_Diff
from utils import enable_dropout
from dataloader.collator import CollatorDiff
from data.generate_graph_cath import pdb2graph, get_processed_graph, amino_acids_type


def ensure_dir(path: str):
    os.makedirs(path, exist_ok=True)


def load_test_names(splits_path: str) -> List[str]:
    with open(splits_path, 'r') as f:
        splits = json.load(f)
    return list(splits.get('test', []))


def pdbid_from_name(name: str) -> str:
    return name.split('.')[0].upper()


def infer_one(model, device, pdb_path: str, ensemble: int, ddim_steps: int) -> Tuple[str, str, float, float]:
    graph_raw = pdb2graph(pdb_path)
    graph = get_processed_graph(graph_raw)
    collator = CollatorDiff()
    g_batch, ipa_batch = collator([graph])
    g_batch, ipa_batch = g_batch.to(device), ipa_batch.to(device)

    model.eval()
    with torch.no_grad():
        ens_logits = []
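        # MC dropout: keep dropout active at eval time so each ensemble forward pass differs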
        enable_dropout(model)
        for _ in range(ensemble):
            logits, sample_graph = model.mc_ddim_sample(g_batch, ipa_batch, diverse=True, step=ddim_steps)
            ens_logits.append(logits)
        ens_logits_tensor = torch.stack(ens_logits)
        mean_logits = ens_logits_tensor.mean(dim=0).cpu()
        true_label = g_batch.x.cpu()
        true_seq = ''.join([amino_acids_type[i] for i in true_label.argmax(dim=1).tolist()])
        pred_seq = ''.join([amino_acids_type[i] for i in mean_logits.argmax(dim=1).tolist()])

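        # Mean NLL over residues; perplexity = exp(NLL); recovery = fraction of correctly predicted residues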
        ll_fullseq = F.cross_entropy(mean_logits, true_label, reduction='mean').item()
        perplexity = float(np.exp(ll_fullseq))
        recovery = float((mean_logits.argmax(dim=1) == true_label.argmax(dim=1)).sum().item() / true_label.shape[0])
    return true_seq, pred_seq, perplexity, recovery


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument('--splits', required=True, help='Path to CATH chain_set_splits.json (4.2)')
    ap.add_argument('--pdb-dir', required=True, help='Directory containing downloaded PDB files (RCSB)')
    ap.add_argument('--out-dir', required=True, help='Output directory for predicted sequences and summary CSV')
    ap.add_argument('--device', default='cuda:0')
    args = ap.parse_args()

    ensure_dir(args.out_dir)
    seq_dir = os.path.join(args.out_dir, 'seqs')
    ensure_dir(seq_dir)

    conf_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'conf'))
    with initialize_config_dir(version_base=None, config_dir=conf_path):
        cfg = compose(config_name="inference")

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    egnn = EGNN_NET(input_feat_dim=cfg.model.input_feat_dim, hidden_channels=cfg.model.hidden_dim,
                    edge_attr_dim=cfg.model.edge_attr_dim, dropout=cfg.model.drop_out,
                    n_layers=cfg.model.depth, update_edge=cfg.model.update_edge,
                    norm_coors=cfg.model.norm_coors, update_coors=cfg.model.update_coors,
                    update_global=cfg.model.update_global, embedding=cfg.model.embedding,
                    embedding_dim=cfg.model.embedding_dim, norm_feat=cfg.model.norm_feat,
                    embed_ss=cfg.model.embed_ss)
    ipa = IPANetPredictor(dropout=cfg.model.ipa_drop_out)
    model = Prior_Diff(egnn, ipa, timesteps=cfg.diffusion.timesteps,
                       objective=cfg.diffusion.objective,
                       noise_type=cfg.diffusion.noise_type,
                       sample_method=cfg.diffusion.sample_method,
                       min_mask_ratio=cfg.mask_prior.min_mask_ratio,
                       dev_mask_ratio=cfg.mask_prior.dev_mask_ratio,
                       marginal_dist_path=cfg.dataset.marginal_train_dir).to(device)

    # Load weights
    checkpoint = torch.load(cfg.test_model.path, map_location=device)
    model.load_state_dict(checkpoint['model'], strict=True)

    # Iterate over test list
    test_names = load_test_names(args.splits)
    summary_rows = []
    for name in tqdm(test_names, desc='MapDiff Inference'):
        pdbid = pdbid_from_name(name)
        pdb_path = os.path.join(args.pdb_dir, f"{pdbid}.pdb")
        if not os.path.exists(pdb_path):
            continue
        try:
            true_seq, pred_seq, ppl, rec = infer_one(model, device, pdb_path, cfg.diffusion.ensemble_num, cfg.diffusion.ddim_steps)
        except Exception as e:
            print(f"Failed {name}: {e}")
            continue

        out_file = os.path.join(seq_dir, f"{name}.fa")
        with open(out_file, 'w') as f:
            f.write(f">{name}\n")
            f.write(pred_seq + "\n")

        summary_rows.append((name, pdbid, len(pred_seq), ppl, rec))

    csv_path = os.path.join(args.out_dir, 'summary.csv')
    with open(csv_path, 'w') as f:
        f.write('name,pdbid,length,perplexity,recovery\n')
        for r in summary_rows:
            f.write(','.join([str(x) for x in r]) + '\n')

    print(f"Done. Wrote {len(summary_rows)} sequences to {seq_dir}")


if __name__ == '__main__':
    main()

Please let me know how I can fix these errors; they do not occur for every protein, but for roughly 20-30% of them. I can provide a complete stack trace if required.
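One more thing I am unsure about: the split names carry a chain ID (e.g. 1cjx.B), but my script passes the whole downloaded PDB file to pdb2graph, so multi-chain entries might explain the tensor size mismatch. Should I extract only the named chain before building the graph? Something like this hypothetical pre-processing step (Biopython, not yet tested against the pipeline):

from Bio.PDB import PDBParser, PDBIO, Select

class ChainSelect(Select):
    """Keep only one chain when writing the structure back out."""
    def __init__(self, chain_id: str):
        self.chain_id = chain_id

    def accept_chain(self, chain):
        return chain.id == self.chain_id

def extract_chain(pdb_path: str, chain_id: str, out_path: str) -> str:
    structure = PDBParser(QUIET=True).get_structure("s", pdb_path)
    io = PDBIO()
    io.set_structure(structure)
    io.save(out_path, ChainSelect(chain_id))
    return out_path

# e.g. extract_chain("cath_download/all/1CJX.pdb", "B", "chains/1cjx_B.pdb")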
