diff --git a/acestep/pipeline_ace_step.py b/acestep/pipeline_ace_step.py index 01207779..a734d1a6 100644 --- a/acestep/pipeline_ace_step.py +++ b/acestep/pipeline_ace_step.py @@ -1046,6 +1046,19 @@ def text2music_diffusion_process( if right_pad_frame_length > 0: padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:]) target_latents = torch.cat(padd_list, dim=-1) + + # Fix shape mismatch between target_latents and x0 + if target_latents.shape[-1] != x0.shape[-1]: + if target_latents.shape[-1] < x0.shape[-1]: + # Pad with zeros if target_latents is shorter + padding = x0.shape[-1] - target_latents.shape[-1] + target_latents = torch.nn.functional.pad( + target_latents, (0, padding), "constant", 0 + ) + else: + # Trim if target_latents is longer + target_latents = target_latents[..., :x0.shape[-1]] + assert ( target_latents.shape[-1] == x0.shape[-1] ), f"{target_latents.shape=} {x0.shape=}"