Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions comfy/audio_encoders/audio_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def __init__(self, config):
elif model_type == "whisper3":
self.model = WhisperLargeV3(**model_config)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000

def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

def get_sd(self):
    """Return the state dict of the wrapped model (``self.model``)."""
    model_state = self.model.state_dict()
    return model_state
Expand Down
4 changes: 4 additions & 0 deletions comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
DynamicVRAM = "dynamic_vram"

parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))

Expand Down Expand Up @@ -257,3 +258,6 @@ def is_valid_directory(path: str) -> str:
# '--fast' is provided with a list of performance features, use that list
else:
args.fast = set(args.fast)

def enables_dynamic_vram():
    """Report whether the dynamic VRAM performance feature should be active.

    Returns:
        True only when ``PerformanceFeature.DynamicVRAM`` is present in
        ``args.fast`` and neither ``args.highvram`` nor ``args.gpu_only``
        is set; False otherwise.
    """
    requested = PerformanceFeature.DynamicVRAM in args.fast
    if not requested:
        return False
    # Either of these modes rules the feature out even when it was requested.
    return not (args.highvram or args.gpu_only)
4 changes: 2 additions & 2 deletions comfy/clip_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def __init__(self, json_config):
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()

self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

def get_sd(self):
    """Return the state dict of the wrapped model (``self.model``)."""
    model_state = self.model.state_dict()
    return model_state
Expand Down
2 changes: 1 addition & 1 deletion comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())

self.compression_ratio = compression_ratio
self.global_average_pooling = global_average_pooling
Expand Down
33 changes: 32 additions & 1 deletion comfy/k_diffusion/sampling.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,49 @@
import math
import time
from functools import partial

from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import trange, tqdm
from tqdm.auto import trange as trange_, tqdm

from . import utils
from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling

import comfy.memory_management


def trange(*args, **kwargs):
    """Create a tqdm ``trange`` bar, adding warm-up handling when aimdo is active.

    When ``comfy.memory_management.aimdo_allocator`` is None this is a plain
    pass-through to tqdm's ``trange``.

    Otherwise the returned bar has its ``update`` method wrapped so that:

    * before the first update the postfix reads " Model Initializing ... ";
    * on the first update the time is recorded and the postfix reports that
      initialization is complete;
    * on the second update ``start_t`` is rewritten so that the first
      iteration appears to have taken as long as the second one, and the
      postfix is cleared.

    Args:
        *args: Positional arguments forwarded to tqdm's ``trange``.
        **kwargs: Keyword arguments forwarded to tqdm's ``trange``.
            ``smoothing`` defaults to 1.0 in the aimdo path unless the caller
            supplies a value.

    Returns:
        The tqdm progress bar object.
    """
    if comfy.memory_management.aimdo_allocator is None:
        return trange_(*args, **kwargs)

    # Default to smoothing=1.0, but let an explicit caller value win: passing
    # it as a second keyword would raise TypeError (duplicate keyword argument).
    kwargs.setdefault("smoothing", 1.0)
    pbar = trange_(*args, **kwargs)
    pbar._i = 0
    pbar.set_postfix_str(" Model Initializing ... ")

    # Keep the bound original so the wrapper can delegate to it.
    _update = pbar.update

    def warmup_update(n=1):
        # NOTE(review): this counts calls to update(), not iterations; tqdm's
        # iterator may batch several fast iterations into one update() call.
        pbar._i += 1
        if pbar._i == 1:
            pbar.i1_time = time.time()
            pbar.set_postfix_str(" Model Initialization complete! ")
        elif pbar._i == 2:
            # Bring forward the effective start time based on the difference
            # between the first and second iteration, to attempt to remove
            # load overhead from the final step rate estimate.
            pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
            pbar.set_postfix_str("")

        _update(n)

    pbar.update = warmup_update
    return pbar


def append_zero(x):
    """Return ``x`` with a single zero element concatenated at the end.

    The zero is created with ``x.new_zeros`` so it matches the dtype and
    device of ``x``.
    """
    trailing_zero = x.new_zeros([1])
    return torch.cat((x, trailing_zero))

Expand Down
4 changes: 2 additions & 2 deletions comfy/ldm/hunyuan_video/upsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ def __init__(self, model_type, config):
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()

self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)
return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())

def get_sd(self):
    """Return the state dict of the wrapped model (``self.model``)."""
    model_state = self.model.state_dict()
    return model_state
Expand Down
81 changes: 81 additions & 0 deletions comfy/memory_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import math
import torch
from typing import NamedTuple

from comfy.quant_ops import QuantizedTensor

class TensorGeometry(NamedTuple):
    """Shape and dtype of a tensor, without any backing storage.

    Exposes ``numel()`` and ``element_size()`` under the same names as the
    corresponding ``torch.Tensor`` methods.
    """

    # NOTE(review): ``any`` is the builtin function, kept as the original
    # placeholder annotation; ``typing.Any`` would be the conventional spelling.
    shape: any
    dtype: torch.dtype

    def element_size(self):
        """Return the size in bytes of one element of ``dtype``.

        Works for every torch dtype, including ``torch.bool`` and complex
        dtypes, which ``torch.iinfo`` / ``torch.finfo`` do not both cover.
        """
        itemsize = getattr(self.dtype, "itemsize", None)
        if itemsize is None:
            # Older torch releases have no ``dtype.itemsize``; ask a
            # zero-element tensor of that dtype instead.
            itemsize = torch.empty(0, dtype=self.dtype).element_size()
        return itemsize

    def numel(self):
        """Return the number of elements implied by ``shape`` (1 for an empty shape)."""
        return math.prod(self.shape)

def tensors_to_geometries(tensors, dtype=None):
    """Convert an iterable of tensors into a list of ``TensorGeometry`` entries.

    ``None`` entries and ``QuantizedTensor`` instances are passed through
    unchanged. For every other entry the dtype is chosen as follows, later
    rules overriding earlier ones: the tensor's own ``dtype``, then its
    ``_model_dtype`` attribute if it has one, then the ``dtype`` argument if
    it is not None.

    Args:
        tensors: Iterable of tensors, ``QuantizedTensor`` objects or None.
        dtype: Optional dtype forced onto every produced geometry.

    Returns:
        A list with one entry per input, in the same order.
    """
    def describe(item):
        if item is None or isinstance(item, QuantizedTensor):
            return item
        resolved = getattr(item, "_model_dtype", item.dtype)
        if dtype is not None:
            resolved = dtype
        return TensorGeometry(shape=item.shape, dtype=resolved)

    return [describe(item) for item in tensors]

def vram_aligned_size(tensor):
    """Return the byte size of ``tensor`` rounded up to a multiple of 1024.

    Accepted inputs:

    * a ``list``: the aligned sizes of its items are summed (recursively);
    * a ``QuantizedTensor``: the aligned sizes of the inner tensors named by
      ``__tensor_flatten__`` are summed;
    * ``None``: counts as 0 bytes;
    * anything else providing ``numel()`` and ``element_size()``.

    Only ``list`` is treated as a container; other sequence types are not
    unpacked.
    """
    if isinstance(tensor, list):
        total = 0
        for item in tensor:
            total += vram_aligned_size(item)
        return total

    if isinstance(tensor, QuantizedTensor):
        attr_names, _ = tensor.__tensor_flatten__()
        return vram_aligned_size([getattr(tensor, name) for name in attr_names])

    if tensor is None:
        return 0

    alignment = 1024
    nbytes = tensor.numel() * tensor.element_size()
    # Round up to the next multiple of the alignment.
    blocks, remainder = divmod(nbytes, alignment)
    if remainder:
        blocks += 1
    return blocks * alignment

def interpret_gathered_like(tensors, gathered):
    """Carve views out of a flat byte buffer, one per entry of ``tensors``.

    Each template in ``tensors`` supplies a shape and dtype; the matching
    slice of ``gathered`` is reinterpreted with that dtype and shape. After
    each slice the running offset advances by ``vram_aligned_size`` of the
    template, not by its raw byte size.

    Args:
        tensors: Iterable of templates. ``None`` entries yield ``None``.
            ``QuantizedTensor`` entries are rebuilt from views of their
            inner tensors; any other entry is used directly as a template.
        gathered: 1-D tensor with a 1-byte element size.

    Returns:
        A list of views (or ``None``), in the same order as ``tensors``.

    Raises:
        ValueError: If ``gathered`` is not 1-D single-byte, or is too small
            to hold a requested view at the current offset.
    """
    if gathered.dim() != 1 or gathered.element_size() != 1:
        raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")

    views = []
    cursor = 0

    for entry in tensors:
        if entry is None:
            views.append(None)
            continue

        quantized = isinstance(entry, QuantizedTensor)
        if quantized:
            attr_names, flatten_ctx = entry.__tensor_flatten__()
            parts = {name: getattr(entry, name) for name in attr_names}
        else:
            parts = {"data": entry}

        carved = {}
        for name, part in parts.items():
            nbytes = part.numel() * part.element_size()
            end = cursor + nbytes
            if end > gathered.numel():
                raise ValueError(f"Buffer too small: needs {end} bytes, but only has {gathered.numel()}. ")
            carved[name] = gathered[cursor:end].view(dtype=part.dtype).view(part.shape)
            # Advance by the aligned size so the next view starts on a boundary.
            cursor += vram_aligned_size(part)

        if quantized:
            # NOTE(review): the trailing 0, 0 presumably fill the outer
            # size/stride parameters of the unflatten call -- confirm.
            views.append(QuantizedTensor.__tensor_unflatten__(carved, flatten_ctx, 0, 0))
        else:
            views.append(carved["data"])

    return views

# Module-level handle for the aimdo allocator. It stays None unless assigned
# elsewhere (not shown here); the trange wrapper in k_diffusion/sampling.py
# compares it against None to decide whether to add warm-up progress handling.
aimdo_allocator = None
15 changes: 7 additions & 8 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)

comfy.model_management.archive_model_dtypes(self.diffusion_model)

self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
Expand Down Expand Up @@ -299,15 +301,15 @@ def extra_conds(self, **kwargs):

return out

def load_model_weights(self, sd, unet_prefix=""):
def load_model_weights(self, sd, unet_prefix="", assign=False):
to_load = {}
keys = list(sd.keys())
for k in keys:
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)

to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))

Expand All @@ -322,18 +324,15 @@ def process_latent_in(self, latent):
def process_latent_out(self, latent):
return self.latent_format.process_out(latent)

def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
extra_sds = []
if clip_state_dict is not None:
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
if vae_state_dict is not None:
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
if clip_vision_state_dict is not None:
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([])

Expand Down Expand Up @@ -776,8 +775,8 @@ def extra_conds(self, **kwargs):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
for k in d:
s = d[k]
Expand Down
Loading
Loading