56 changes: 47 additions & 9 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -24,7 +24,8 @@
 from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
-
+from ..utils import logging
+logger = logging.get_logger(__name__)

 if is_scipy_available():
     import scipy.stats
@@ -411,29 +412,34 @@ def set_timesteps(
         if self.config.use_karras_sigmas:
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
-            if self.config.beta_schedule != "squaredcos_cap_v2":
-                timesteps = timesteps.round()
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps)
+
         elif self.config.use_lu_lambdas:
             lambdas = np.flip(log_sigmas.copy())
             lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
             sigmas = np.exp(lambdas)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
-            if self.config.beta_schedule != "squaredcos_cap_v2":
-                timesteps = timesteps.round()
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps)
+
         elif self.config.use_exponential_sigmas:
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps)
+
         elif self.config.use_beta_sigmas:
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps)
+
         elif self.config.use_flow_sigmas:
             alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
             sigmas = 1.0 - alphas
             sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
             timesteps = (sigmas * self.config.num_train_timesteps).copy()
+
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

@@ -544,6 +550,38 @@ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
         t = t.reshape(sigma.shape)
         return t

+    def _ensure_unique_timesteps(self, timesteps, sigmas, num_inference_steps):
+        """
+        Drop duplicate timesteps while preserving their correspondence with sigmas.
+
+        Args:
+            timesteps (`np.ndarray`):
+                The timestep values, which may contain duplicates.
+            sigmas (`np.ndarray`):
+                The sigma values corresponding to the timesteps.
+            num_inference_steps (`int`):
+                The number of inference steps originally requested.
+
+        Returns:
+            `Tuple[np.ndarray, np.ndarray]`:
+                A tuple `(timesteps, sigmas)` where the timesteps are unique and the sigmas are filtered accordingly.
+        """
+        unique_timesteps, unique_indices = np.unique(timesteps, return_index=True)
+
+        if len(unique_timesteps) < len(timesteps):
+            # sort the first-occurrence indices to restore the original (descending) timestep order
+            unique_indices_sorted = np.sort(unique_indices)
+            timesteps = timesteps[unique_indices_sorted]
+            sigmas = sigmas[unique_indices_sorted]
+
+        if len(timesteps) < num_inference_steps:
+            logger.warning(
+                f"Due to the current scheduler configuration, only {len(timesteps)} unique timesteps "
+                f"could be generated instead of the requested {num_inference_steps}."
+            )
+
+        return timesteps, sigmas
+
     def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Convert sigma values to alpha_t and sigma_t values.
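Note: the deduplication in `_ensure_unique_timesteps` hinges on `np.unique(..., return_index=True)` returning the first-occurrence index of each unique value. A minimal standalone sketch of that behavior, using made-up timestep values (not taken from this PR) that mimic the duplicates produced by rounding under `squaredcos_cap_v2`:

```python
import numpy as np

# Hypothetical rounded timesteps, descending with duplicates, as can arise
# when distinct sigmas map to timesteps that round to the same integer.
timesteps = np.array([999, 998, 998, 997, 995, 995, 990])
sigmas = np.linspace(1.0, 0.1, len(timesteps))

# np.unique sorts the values ascending and, with return_index=True,
# also returns the first-occurrence index of each unique value.
_, unique_indices = np.unique(timesteps, return_index=True)

# Sorting those indices restores the original descending order.
keep = np.sort(unique_indices)
print(timesteps[keep])  # [999 998 997 995 990]
print(sigmas[keep])     # the five matching sigmas, duplicates dropped
```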
28 changes: 28 additions & 0 deletions tests/schedulers/test_scheduler_dpm_multi.py
@@ -366,3 +366,31 @@ def test_beta_sigmas(self):

     def test_exponential_sigmas(self):
         self.check_over_configs(use_exponential_sigmas=True)
+
+    def test_no_duplicate_timesteps_with_sigma_methods(self):
+        sigma_configs = [
+            {"use_karras_sigmas": True},
+            {"use_lu_lambdas": True},
+            {"use_exponential_sigmas": True},
+            {"use_beta_sigmas": True},
+        ]
+
+        for config in sigma_configs:
+            scheduler = DPMSolverMultistepScheduler(
+                num_train_timesteps=1000,
+                beta_schedule="squaredcos_cap_v2",
+                **config,
+            )
+            scheduler.set_timesteps(20)
+
+            sample = torch.randn(4, 3, 32, 32)
+
+            try:
+                for t in scheduler.timesteps:
+                    model_output = torch.randn_like(sample)
+                    output = scheduler.step(model_output, t, sample)
+                    sample = output.prev_sample
+            except IndexError as e:
+                self.fail(f"Index error occurred with config {config}: {e}")
+            except Exception as e:
+                self.fail(f"Unexpected error with config {config}: {e}")