Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions acestep/pipeline_ace_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,19 @@ def text2music_diffusion_process(
if right_pad_frame_length > 0:
padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:])
target_latents = torch.cat(padd_list, dim=-1)

# Fix shape mismatch between target_latents and x0
if target_latents.shape[-1] != x0.shape[-1]:
if target_latents.shape[-1] < x0.shape[-1]:
# Pad with zeros if target_latents is shorter
padding = x0.shape[-1] - target_latents.shape[-1]
target_latents = torch.nn.functional.pad(
target_latents, (0, padding), "constant", 0
)
else:
# Trim if target_latents is longer
target_latents = target_latents[..., :x0.shape[-1]]

assert (
target_latents.shape[-1] == x0.shape[-1]
), f"{target_latents.shape=} {x0.shape=}"
Expand Down