3 changes: 3 additions & 0 deletions comfy/latent_formats.py
@@ -8,6 +8,7 @@ class LatentFormat:
     latent_rgb_factors_bias = None
     latent_rgb_factors_reshape = None
     taesd_decoder_name = None
+    spacial_downscale_ratio = 8

     def process_in(self, latent):
         return latent * self.scale_factor
@@ -181,6 +182,7 @@ def process_out(self, latent):

 class Flux2(LatentFormat):
     latent_channels = 128
+    spacial_downscale_ratio = 16

     def __init__(self):
         self.latent_rgb_factors =[
@@ -749,6 +751,7 @@ class ACEAudio(LatentFormat):

 class ChromaRadiance(LatentFormat):
     latent_channels = 3
+    spacial_downscale_ratio = 1

     def __init__(self):
         self.latent_rgb_factors = [
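The new spacial_downscale_ratio class attribute records how many image pixels one latent cell spans along each spatial axis. A minimal sketch (helper name hypothetical) of the mapping it encodes:

def latent_spatial_dims(height, width, spacial_downscale_ratio):
    # Pixel dimensions map to latent dimensions by integer division.
    return height // spacial_downscale_ratio, width // spacial_downscale_ratio

assert latent_spatial_dims(1024, 1024, 8) == (128, 128)     # SD-style VAEs (default)
assert latent_spatial_dims(1024, 1024, 16) == (64, 64)      # Flux2
assert latent_spatial_dims(1024, 1024, 1) == (1024, 1024)   # ChromaRadiance works in pixel space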
27 changes: 17 additions & 10 deletions comfy/ldm/wan/vae.py
@@ -5,7 +5,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from comfy.ldm.modules.diffusionmodules.model import vae_attention
+from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed

 import comfy.ops
 ops = comfy.ops.disable_weight_init
@@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._padding = (self.padding[2], self.padding[2], self.padding[1],
-                         self.padding[1], 2 * self.padding[0], 0)
-        self.padding = (0, 0, 0)
+        self._padding = 2 * self.padding[0]
+        self.padding = (0, self.padding[1], self.padding[2])

     def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
         if cache_list is not None:
             cache_x = cache_list[cache_idx]
             cache_list[cache_idx] = None

-        padding = list(self._padding)
-        if cache_x is not None and self._padding[4] > 0:
-            cache_x = cache_x.to(x.device)
-            x = torch.cat([cache_x, x], dim=2)
-            padding[4] -= cache_x.shape[2]
+        if cache_x is None and x.shape[2] == 1:
+            #Fast path - the op will pad for use by truncating the weight
+            #and save math on a pile of zeros.
+            return super().forward(x, autopad="causal_zero")
+
+        if self._padding > 0:
+            padding_needed = self._padding
+            if cache_x is not None:
+                cache_x = cache_x.to(x.device)
+                padding_needed = max(0, padding_needed - cache_x.shape[2])
+            padding_shape = list(x.shape)
+            padding_shape[2] = padding_needed
+            padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
+            x = torch_cat_if_needed([padding, cache_x, x], dim=2)
+            del cache_x
-        x = F.pad(x, padding)

         return super().forward(x)
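For orientation, a standalone sketch (sizes hypothetical) of the padding arithmetic in the rewritten forward: causal front padding is kT - 1 zero frames for temporal kernel size kT, and frames cached from the previous chunk substitute for those zeros one-for-one.

import torch

kT = 3                                   # temporal kernel size, so self._padding = kT - 1 = 2
x = torch.randn(1, 4, 5, 8, 8)           # (batch, channels, frames, height, width)
cache_x = torch.randn(1, 4, 2, 8, 8)     # trailing frames kept from the previous chunk

padding_needed = kT - 1                  # 2 zero frames when nothing is cached
padding_needed = max(0, padding_needed - cache_x.shape[2])
assert padding_needed == 0               # the cache fully covers the causal context

x_in = torch.cat([cache_x, x], dim=2)    # no zeros prepended here
assert x_in.shape[2] - kT + 1 == x.shape[2]   # conv output keeps all 5 new frames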
10 changes: 6 additions & 4 deletions comfy/ops.py
@@ -203,7 +203,9 @@ class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
     def reset_parameters(self):
         return None

-    def _conv_forward(self, input, weight, bias, *args, **kwargs):
+    def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
+        if autopad == "causal_zero":
+            weight = weight[:, :, -input.shape[2]:, :, :]
         if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
             out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
             if bias is not None:
@@ -212,15 +214,15 @@ def _conv_forward(self, input, weight, bias, *args, **kwargs):
         else:
             return super()._conv_forward(input, weight, bias, *args, **kwargs)

-    def forward_comfy_cast_weights(self, input):
+    def forward_comfy_cast_weights(self, input, autopad=None):
         weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-        x = self._conv_forward(input, weight, bias)
+        x = self._conv_forward(input, weight, bias, autopad=autopad)
         uncast_bias_weight(self, weight, bias, offload_stream)
         return x

     def forward(self, *args, **kwargs):
         run_every_op()
-        if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+        if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
             return self.forward_comfy_cast_weights(*args, **kwargs)
         else:
             return super().forward(*args, **kwargs)
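To see why the causal_zero fast path is valid: convolving a single frame against a causally zero-padded input equals convolving the unpadded frame against only the trailing temporal slice of the kernel, because every other tap lands on zeros. A standalone sketch (shapes hypothetical) checking that equivalence:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(8, 4, 3, 3, 3)      # (out_c, in_c, kT=3, kH, kW)
x = torch.randn(1, 4, 1, 16, 16)         # a single frame, T=1

# Slow path: prepend kT - 1 = 2 zero frames, then run the full kernel.
pad = torch.zeros(1, 4, 2, 16, 16)
slow = F.conv3d(torch.cat([pad, x], dim=2), weight, padding=(0, 1, 1))

# Fast path: truncate the kernel to its last T temporal slices instead.
fast = F.conv3d(x, weight[:, :, -x.shape[2]:, :, :], padding=(0, 1, 1))

assert torch.allclose(slow, fast, atol=1e-5)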
12 changes: 9 additions & 3 deletions comfy/sample.py
@@ -37,12 +37,18 @@ def prepare_noise(latent_image, seed, noise_inds=None):

     return noises

-def fix_empty_latent_channels(model, latent_image):
+def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
     if latent_image.is_nested:
         return latent_image
     latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
-    if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
-        latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
+    if torch.count_nonzero(latent_image) == 0:
+        if latent_format.latent_channels != latent_image.shape[1]:
+            latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
+        if downscale_ratio_spacial is not None:
+            if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
+                ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
+                latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled")
+
     if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
         latent_image = latent_image.unsqueeze(2)
     return latent_image
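A worked example of the resize arithmetic (values hypothetical): an empty latent tagged with downscale_ratio_spacial=8 (e.g. from an SD-style node) fed to a model whose format declares 16 (Flux2) must shrink by half per axis so it still decodes to the same pixel size.

downscale_ratio_spacial = 8      # ratio the empty latent was created with
model_ratio = 16                 # latent_format.spacial_downscale_ratio for Flux2
ratio = downscale_ratio_spacial / model_ratio   # 0.5

latent_h, latent_w = 128, 128    # 1024x1024 pixels at ratio 8
new_h, new_w = round(latent_h * ratio), round(latent_w * ratio)
assert (new_h, new_w) == (64, 64)   # still 1024x1024 pixels at ratio 16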
6 changes: 4 additions & 2 deletions comfy_extras/nodes_custom_sampler.py
@@ -741,7 +741,7 @@ def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler,
         latent = latent_image
         latent_image = latent["samples"]
         latent = latent.copy()
-        latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
+        latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
         latent["samples"] = latent_image

         if not add_noise:
@@ -760,6 +760,7 @@ def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler,
         samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)

         out = latent.copy()
+        out.pop("downscale_ratio_spacial", None)
         out["samples"] = samples
         if "x0" in x0_output:
             x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
@@ -939,7 +940,7 @@ def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput:
         latent = latent_image
         latent_image = latent["samples"]
         latent = latent.copy()
-        latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
+        latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None))
         latent["samples"] = latent_image

         noise_mask = None
@@ -954,6 +955,7 @@ def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput:
         samples = samples.to(comfy.model_management.intermediate_device())

         out = latent.copy()
+        out.pop("downscale_ratio_spacial", None)
         out["samples"] = samples
         if "x0" in x0_output:
             x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
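The pop mirrors the get: downscale_ratio_spacial describes only an empty placeholder latent, so sampler nodes consume it rather than forward it alongside real samples. A hypothetical trace of the latent dict through one of these nodes:

import torch

latent = {"samples": torch.zeros(1, 4, 128, 128), "downscale_ratio_spacial": 8}
# fix_empty_latent_channels may re-channel/resize the placeholder here.
out = latent.copy()
out.pop("downscale_ratio_spacial", None)        # output samples are model-native
out["samples"] = torch.randn(1, 4, 128, 128)    # stand-in for sampler output
assert "downscale_ratio_spacial" not in out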
2 changes: 1 addition & 1 deletion comfy_extras/nodes_sd3.py
@@ -55,7 +55,7 @@ def define_schema(cls):
     @classmethod
     def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
-        return io.NodeOutput({"samples":latent})
+        return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})

     generate = execute # TODO: remove
5 changes: 3 additions & 2 deletions nodes.py
@@ -1230,7 +1230,7 @@ def INPUT_TYPES(s):

     def generate(self, width, height, batch_size=1):
         latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
-        return ({"samples":latent}, )
+        return ({"samples": latent, "downscale_ratio_spacial": 8}, )


 class LatentFromBatch:
@@ -1538,7 +1538,7 @@ def set_mask(self, samples, mask):

 def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
     latent_image = latent["samples"]
-    latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
+    latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))

     if disable_noise:
         noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
@@ -1556,6 +1556,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
                             denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
                             force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
     out = latent.copy()
+    out.pop("downscale_ratio_spacial", None)
     out["samples"] = samples
     return (out, )