I have trained a State transition model that uses the X_state embedding key and has a decoder that outputs the full gene set. Now, when I run state tx infer and specify the embed-key, the decoder's pert_cell_counts_preds output is not written to X in the final h5ad file that infer produces.
Here is the command:
state tx infer \
--output /media/model/state_output \
--adata /media/test_template.h5ad \
--pert-col target_gene \
--checkpoint /media/model/state_output/checkpoints/step=64000.ckpt \
--embed-key X_state
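
To make the problem easier to check, here is a rough sketch for inspecting what ended up in the output file. It assumes the output is a regular AnnData h5ad that can be read with anndata.read_h5ad. The helper names summarize_layout and summarize_file are only illustrative and are not part of state. The snippet does not assume any particular key name for the decoder output; it only lists the shape of X and of every obsm entry, so you can see whether X has the full gene dimension or the embedding dimension.

```python
from types import SimpleNamespace


def summarize_layout(adata):
    """Return the shape of X and of every obsm entry of an AnnData-like object."""
    summary = {"X": None if adata.X is None else tuple(adata.X.shape)}
    for key in adata.obsm.keys():
        summary[f"obsm/{key}"] = tuple(adata.obsm[key].shape)
    return summary


def summarize_file(path):
    """Read an h5ad file and summarize it (hypothetical helper, needs anndata)."""
    import anndata as ad  # third-party; imported lazily so the demo below runs without it

    return summarize_layout(ad.read_h5ad(path))


# Stand-in object with made-up shapes, only to show the shape of the result.
# With a real file, call summarize_file("<path to the h5ad written by infer>").
demo = SimpleNamespace(
    X=SimpleNamespace(shape=(3, 5)),
    obsm={"X_state": SimpleNamespace(shape=(3, 2))},
)
print(summarize_layout(demo))
```

If the second dimension of X matches the embedding size rather than the number of genes, that would confirm the decoder output is not the one being stored in X.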