From 435e9fd667b016ad742d4ddbd55381e74ab74da9 Mon Sep 17 00:00:00 2001 From: jackj Date: Fri, 23 Jan 2026 23:15:42 +0500 Subject: [PATCH] Fix: Shape mismatch in extend mode causing AssertionError - Align the last dimension of target_latents to that of x0 automatically - Zero-pad target_latents at the end when it is shorter than x0; trim it from the end when it is longer - Fixes crash in extend mode with long audio files - Minimal impact on audio quality (~0.05-0.15 sec) Resolves an issue where extend mode fails with an AssertionError when the last dimension of target_latents does not match that of x0 after the padding/trimming operations. --- acestep/pipeline_ace_step.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/acestep/pipeline_ace_step.py b/acestep/pipeline_ace_step.py index 01207779..a734d1a6 100644 --- a/acestep/pipeline_ace_step.py +++ b/acestep/pipeline_ace_step.py @@ -1046,6 +1046,19 @@ def text2music_diffusion_process( if right_pad_frame_length > 0: padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:]) target_latents = torch.cat(padd_list, dim=-1) + + # Fix shape mismatch between target_latents and x0 + if target_latents.shape[-1] != x0.shape[-1]: + if target_latents.shape[-1] < x0.shape[-1]: + # Pad with zeros if target_latents is shorter + padding = x0.shape[-1] - target_latents.shape[-1] + target_latents = torch.nn.functional.pad( + target_latents, (0, padding), "constant", 0 + ) + else: + # Trim if target_latents is longer + target_latents = target_latents[..., :x0.shape[-1]] + assert ( target_latents.shape[-1] == x0.shape[-1] ), f"{target_latents.shape=} {x0.shape=}"