From bdeac8897e522b9637a6a427fdc8a50a6abd6b20 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Wed, 21 Jan 2026 15:36:02 -0800 Subject: [PATCH 1/6] feat: Add search_aliases field to node schema (#12010) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Add search_aliases field to node schema Adds `search_aliases` field to improve node discoverability. Users can define alternative search terms for nodes (e.g., "text concat" → StringConcatenate). Changes: - Add `search_aliases: list[str]` to V3 Schema - Add `SEARCH_ALIASES` support for V1 nodes - Include field in `/object_info` response - Add aliases to high-priority core nodes V1 usage: ```python class MyNode: SEARCH_ALIASES = ["alt name", "synonym"] ``` V3 usage: ```python io.Schema( node_id="MyNode", search_aliases=["alt name", "synonym"], ... ) ``` ## Related PRs - Frontend: Comfy-Org/ComfyUI_frontend#XXXX (draft - merge after this) - Docs: Comfy-Org/docs#XXXX (draft - merge after stable) * Propagate search_aliases through V3 Schema.get_v1_info to NodeInfoV1 --- comfy_api/latest/_io.py | 4 ++++ comfy_extras/nodes_post_processing.py | 1 + comfy_extras/nodes_preview_any.py | 1 + comfy_extras/nodes_string.py | 1 + comfy_extras/nodes_upscale_model.py | 1 + nodes.py | 15 +++++++++++++++ server.py | 2 ++ 7 files changed, 25 insertions(+) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4969d3506436..a60020ca8257 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1249,6 +1249,7 @@ class NodeInfoV1: experimental: bool=None api_node: bool=None price_badge: dict | None = None + search_aliases: list[str]=None @dataclass class NodeInfoV3: @@ -1346,6 +1347,8 @@ class Schema: hidden: list[Hidden] = field(default_factory=list) description: str="" """Node description, shown as a tooltip when hovering over the node.""" + search_aliases: list[str] = field(default_factory=list) + """Alternative names for search. Useful for synonyms, abbreviations, or old names after renaming.""" is_input_list: bool = False """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes. @@ -1483,6 +1486,7 @@ def get_v1_info(self, cls) -> NodeInfoV1: api_node=self.is_api_node, python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, + search_aliases=self.search_aliases if self.search_aliases else None, ) return info diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 2e559c35c9ff..6011275d6370 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -550,6 +550,7 @@ def define_schema(cls): node_id="BatchImagesNode", display_name="Batch Images", category="image", + search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], inputs=[ io.Autogrow.Input("images", template=autogrow_template) ], diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index 139b07c93e9e..91502ebf2b01 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -16,6 +16,7 @@ def INPUT_TYPES(cls): OUTPUT_NODE = True CATEGORY = "utils" + SEARCH_ALIASES = ["preview", "show", "display", "view", "show text", "display text", "preview text", "show output", "inspect", "debug"] def main(self, source=None): value = 'None' diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 571d89f6268f..a2d5f0d94142 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -11,6 +11,7 @@ def define_schema(cls): node_id="StringConcatenate", display_name="Concatenate", category="utils/string", + search_aliases=["text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"], inputs=[ io.String.Input("string_a", multiline=True), io.String.Input("string_b", multiline=True), diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index ed587851c521..97b9e948d619 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -53,6 +53,7 @@ def define_schema(cls): node_id="ImageUpscaleWithModel", display_name="Upscale Image (using Model)", category="image/upscaling", + search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"], inputs=[ io.UpscaleModel.Input("upscale_model"), io.Image.Input("image"), diff --git a/nodes.py b/nodes.py index ea5d6e525d4b..67b61dcfeeb2 100644 --- a/nodes.py +++ b/nodes.py @@ -70,6 +70,7 @@ def INPUT_TYPES(s) -> InputTypeDict: CATEGORY = "conditioning" DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." + SEARCH_ALIASES = ["text", "prompt", "text prompt", "positive prompt", "negative prompt", "encode text", "text encoder", "encode prompt"] def encode(self, clip, text): if clip is None: @@ -86,6 +87,7 @@ def INPUT_TYPES(s): FUNCTION = "combine" CATEGORY = "conditioning" + SEARCH_ALIASES = ["combine", "merge conditioning", "combine prompts", "merge prompts", "mix prompts", "add prompt"] def combine(self, conditioning_1, conditioning_2): return (conditioning_1 + conditioning_2, ) @@ -294,6 +296,7 @@ def INPUT_TYPES(s): CATEGORY = "latent" DESCRIPTION = "Decodes latent images back into pixel space images." + SEARCH_ALIASES = ["decode", "decode latent", "latent to image", "render latent"] def decode(self, vae, samples): latent = samples["samples"] @@ -346,6 +349,7 @@ def INPUT_TYPES(s): FUNCTION = "encode" CATEGORY = "latent" + SEARCH_ALIASES = ["encode", "encode image", "image to latent"] def encode(self, vae, pixels): t = vae.encode(pixels) @@ -581,6 +585,7 @@ def INPUT_TYPES(s): CATEGORY = "loaders" DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents." + SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"] def load_checkpoint(self, ckpt_name): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) @@ -667,6 +672,7 @@ def INPUT_TYPES(s): CATEGORY = "loaders" DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together." + SEARCH_ALIASES = ["lora", "load lora", "apply lora", "lora loader", "lora model"] def load_lora(self, model, clip, lora_name, strength_model, strength_clip): if strength_model == 0 and strength_clip == 0: @@ -814,6 +820,7 @@ def INPUT_TYPES(s): FUNCTION = "load_controlnet" CATEGORY = "loaders" + SEARCH_ALIASES = ["controlnet", "control net", "cn", "load controlnet", "controlnet loader"] def load_controlnet(self, control_net_name): controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) @@ -890,6 +897,7 @@ def INPUT_TYPES(s): FUNCTION = "apply_controlnet" CATEGORY = "conditioning/controlnet" + SEARCH_ALIASES = ["controlnet", "apply controlnet", "use controlnet", "control net"] def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]): if strength == 0: @@ -1200,6 +1208,7 @@ def INPUT_TYPES(s): CATEGORY = "latent" DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling." + SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"] def generate(self, width, height, batch_size=1): latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) @@ -1540,6 +1549,7 @@ def INPUT_TYPES(s): CATEGORY = "sampling" DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image." + SEARCH_ALIASES = ["sampler", "sample", "generate", "denoise", "diffuse", "txt2img", "img2img"] def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) @@ -1604,6 +1614,7 @@ def INPUT_TYPES(s): CATEGORY = "image" DESCRIPTION = "Saves the input images to your ComfyUI output directory." + SEARCH_ALIASES = ["save", "save image", "export image", "output image", "write image", "download"] def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append @@ -1640,6 +1651,8 @@ def __init__(self): self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) self.compress_level = 1 + SEARCH_ALIASES = ["preview", "preview image", "show image", "view image", "display image", "image viewer"] + @classmethod def INPUT_TYPES(s): return {"required": @@ -1658,6 +1671,7 @@ def INPUT_TYPES(s): } CATEGORY = "image" + SEARCH_ALIASES = ["load image", "open image", "import image", "image input", "upload image", "read image", "image loader"] RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" @@ -1810,6 +1824,7 @@ def INPUT_TYPES(s): FUNCTION = "upscale" CATEGORY = "image/upscaling" + SEARCH_ALIASES = ["resize", "resize image", "scale image", "image resize", "zoom", "zoom in", "change size"] def upscale(self, image, upscale_method, width, height, crop): if width == 0 and height == 0: diff --git a/server.py b/server.py index 04a5774883d2..1888745b7438 100644 --- a/server.py +++ b/server.py @@ -682,6 +682,8 @@ def node_info(node_class): if hasattr(obj_class, 'API_NODE'): info['api_node'] = obj_class.API_NODE + + info['search_aliases'] = getattr(obj_class, 'SEARCH_ALIASES', []) return info @routes.get("/object_info") From abe2ec26a61ff670b9c0e71e4821c873368c8728 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:44:28 -0800 Subject: [PATCH 2/6] Support the Anima model. (#12012) --- comfy/ldm/anima/model.py | 202 +++++++++++++++++++++++++++++++++++ comfy/model_base.py | 22 ++++ comfy/model_detection.py | 2 + comfy/sd.py | 7 ++ comfy/supported_models.py | 33 +++++- comfy/text_encoders/anima.py | 61 +++++++++++ 6 files changed, 326 insertions(+), 1 deletion(-) create mode 100644 comfy/ldm/anima/model.py create mode 100644 comfy/text_encoders/anima.py diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py new file mode 100644 index 000000000000..2e6ed58fa49e --- /dev/null +++ b/comfy/ldm/anima/model.py @@ -0,0 +1,202 @@ +from comfy.ldm.cosmos.predict2 import MiniTrainDIT +import torch +from torch import nn +import torch.nn.functional as F + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + x_embed = (x * cos) + (rotate_half(x) * sin) + return x_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, head_dim): + super().__init__() + self.rope_theta = 10000 + inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Attention(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None): + super().__init__() + + inner_dim = head_dim * n_heads + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + + self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None): + context = x if context is None else context + input_shape = x.shape[:-1] + q_shape = (*input_shape, self.n_heads, self.head_dim) + context_shape = context.shape[:-1] + kv_shape = (*context_shape, self.n_heads, self.head_dim) + + query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2) + value_states = self.v_proj(context).view(kv_shape).transpose(1, 2) + + if position_embeddings is not None: + assert position_embeddings_context is not None + cos, sin = position_embeddings + query_states = apply_rotary_pos_emb(query_states, cos, sin) + cos, sin = position_embeddings_context + key_states = apply_rotary_pos_emb(key_states, cos, sin) + + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask) + + attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + def init_weights(self): + torch.nn.init.zeros_(self.o_proj.weight) + + +class TransformerBlock(nn.Module): + def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None): + super().__init__() + self.use_self_attn = use_self_attn + + if self.use_self_attn: + self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention( + query_dim=model_dim, + context_dim=model_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + query_dim=model_dim, + context_dim=source_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.mlp = nn.Sequential( + operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype), + nn.GELU(), + operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype) + ) + + def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None): + if self.use_self_attn: + normed = self.norm_self_attn(x) + attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings) + x = x + attn_out + + normed = self.norm_cross_attn(x) + attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + x = x + attn_out + + x = x + self.mlp(self.norm_mlp(x)) + return x + + def init_weights(self): + torch.nn.init.zeros_(self.mlp[2].weight) + self.cross_attn.init_weights() + + +class LLMAdapter(nn.Module): + def __init__( + self, + source_dim=1024, + target_dim=1024, + model_dim=1024, + num_layers=6, + num_heads=16, + use_self_attn=True, + layer_norm=False, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + + self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype) + if model_dim != target_dim: + self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype) + else: + self.in_proj = nn.Identity() + self.rotary_emb = RotaryEmbedding(model_dim//num_heads) + self.blocks = nn.ModuleList([ + TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers) + ]) + self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype) + self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype) + + def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None): + if target_attention_mask is not None: + target_attention_mask = target_attention_mask.to(torch.bool) + if target_attention_mask.ndim == 2: + target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1) + + if source_attention_mask is not None: + source_attention_mask = source_attention_mask.to(torch.bool) + if source_attention_mask.ndim == 2: + source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) + + x = self.in_proj(self.embed(target_input_ids)) + context = source_hidden_states + position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) + position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) + position_embeddings = self.rotary_emb(x, position_ids) + position_embeddings_context = self.rotary_emb(x, position_ids_context) + for block in self.blocks: + x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + return self.norm(self.out_proj(x)) + + +class Anima(MiniTrainDIT): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations")) + + def preprocess_text_embeds(self, text_embeds, text_ids): + if text_ids is not None: + return self.llm_adapter(text_embeds, text_ids) + else: + return text_embeds diff --git a/comfy/model_base.py b/comfy/model_base.py index 28ba2643e334..1d57562ccd8a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -49,6 +49,7 @@ import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model +import comfy.ldm.anima.model import comfy.model_management import comfy.patcher_extension @@ -1147,6 +1148,27 @@ def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): sigma = (sigma / (sigma + 1)) return latent_image / (1.0 - sigma) +class Anima(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + t5xxl_ids = kwargs.get("t5xxl_ids", None) + t5xxl_weights = kwargs.get("t5xxl_weights", None) + device = kwargs["device"] + if cross_attn is not None: + if t5xxl_ids is not None: + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device)) + if t5xxl_weights is not None: + cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn) + + if cross_attn.shape[1] < 512: + cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1])) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index dad206a2f842..b29a033ccb82 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -550,6 +550,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 dit_config = {} dit_config["image_model"] = "cosmos_predict2" + if "{}llm_adapter.blocks.0.cross_attn.q_proj.weight".format(key_prefix) in state_dict_keys: + dit_config["image_model"] = "anima" dit_config["max_img_h"] = 240 dit_config["max_img_w"] = 240 dit_config["max_frames"] = 128 diff --git a/comfy/sd.py b/comfy/sd.py index 77700dfd3b50..f7f6a44a0651 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -57,6 +57,7 @@ import comfy.text_encoders.kandinsky5 import comfy.text_encoders.jina_clip_2 import comfy.text_encoders.newbie +import comfy.text_encoders.anima import comfy.model_patcher import comfy.lora @@ -1048,6 +1049,7 @@ class TEModel(Enum): GEMMA_3_12B = 18 JINA_CLIP_2 = 19 QWEN3_8B = 20 + QWEN3_06B = 21 def detect_te_model(sd): @@ -1093,6 +1095,8 @@ def detect_te_model(sd): return TEModel.QWEN3_2B elif weight.shape[0] == 4096: return TEModel.QWEN3_8B + elif weight.shape[0] == 1024: + return TEModel.QWEN3_06B if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B @@ -1233,6 +1237,9 @@ class EmptyClass: elif te_model == TEModel.JINA_CLIP_2: clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper + elif te_model == TEModel.QWEN3_06B: + clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c8a7f6efb598..70abebf46c4c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -23,6 +23,7 @@ import comfy.text_encoders.hunyuan_image import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image +import comfy.text_encoders.anima from . import supported_models_base from . import latent_formats @@ -992,6 +993,36 @@ def clip_target(self, state_dict={}): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect)) +class Anima(supported_models_base.BASE): + unet_config = { + "image_model": "anima", + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + memory_usage_factor = 1.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Anima(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect)) + class CosmosI2VPredict2(CosmosT2IPredict2): unet_config = { "image_model": "cosmos_predict2", @@ -1551,6 +1582,6 @@ def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py new file mode 100644 index 000000000000..41f95bcb6600 --- /dev/null +++ b/comfy/text_encoders/anima.py @@ -0,0 +1,61 @@ +from transformers import Qwen2Tokenizer, T5TokenizerFast +import comfy.text_encoders.llama +from comfy import sd1_clip +import os +import torch + + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) + +class AnimaTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.qwen3_06b = Qwen3Tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs) + out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0 + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + return self.t5xxl.untokenize(token_weight_pair) + + def state_dict(self): + return {} + + +class Qwen3_06BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class AnimaTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_06b", clip_model=Qwen3_06BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out = super().encode_token_weights(token_weight_pairs) + out[2]["t5xxl_ids"] = torch.tensor(list(map(lambda a: a[0], token_weight_pairs["t5xxl"][0])), dtype=torch.int) + out[2]["t5xxl_weights"] = torch.tensor(list(map(lambda a: a[1], token_weight_pairs["t5xxl"][0]))) + return out + +def te(dtype_llama=None, llama_quantization_metadata=None): + class AnimaTEModel_(AnimaTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return AnimaTEModel_ From f09904720dc8b56bc6823ebdaf5de69465448e46 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 21 Jan 2026 20:01:35 -0800 Subject: [PATCH 3/6] Fix for edge case of EasyCache when conditionings change during a sampling run (like with timestep scheduling) (#12020) --- comfy_extras/nodes_easycache.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 11b23ffdbd84..90d730df6283 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -29,8 +29,10 @@ def easycache_forward_wrapper(executor, *args, **kwargs): do_easycache = easycache.should_do_easycache(sigmas) if do_easycache: easycache.check_metadata(x) + # if there isn't a cache diff for current conds, we cannot skip this step + can_apply_cache_diff = easycache.can_apply_cache_diff(uuids) # if first cond marked this step for skipping, skip it and use appropriate cached values - if easycache.skip_current_step: + if easycache.skip_current_step and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}") return easycache.apply_cache_diff(x, uuids) @@ -44,7 +46,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs): if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm easycache.cumulative_change_rate += approx_output_change_rate - if easycache.cumulative_change_rate < easycache.reuse_threshold: + if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") # other conds should also skip this step, and instead use their cached values @@ -240,6 +242,9 @@ def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> t return to_return.clone() return to_return + def can_apply_cache_diff(self, uuids: list[UUID]) -> bool: + return all(uuid in self.uuid_cache_diffs for uuid in uuids) + def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]): if self.first_cond_uuid in uuids: self.total_steps_skipped += 1 From 3365ad18a5e0c86b23c6272e5adcedd333fc45cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 22 Jan 2026 06:03:51 +0200 Subject: [PATCH 4/6] Support LTX2 tiny vae (taeltx_2) (#11929) --- comfy/sd.py | 5 ++--- comfy/taesd/taehv.py | 51 ++++++++++++++++++++++++++++---------------- latent_preview.py | 2 +- nodes.py | 2 +- 4 files changed, 37 insertions(+), 23 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index f7f6a44a0651..ce7e6bcff563 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -636,14 +636,13 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2): self.upscale_index_formula = (4, 16, 16) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) self.downscale_index_formula = (4, 16, 16) - if self.latent_channels == 48: # Wan 2.2 + if self.latent_channels in [48, 128]: # Wan 2.2 and LTX2 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling - self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.process_input = self.process_output = lambda image: image self.process_output = lambda image: image self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15) - self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index 0e5f9a378b58..6c06ce19d1d5 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -112,7 +112,8 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar): class TAEHV(nn.Module): - def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True): + def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True), + latent_format=None, show_progress_bar=False): super().__init__() self.image_channels = 3 self.patch_size = 1 @@ -124,6 +125,9 @@ def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x) if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 self.patch_size = 2 + elif self.latent_channels == 128: # LTX2 + self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True) + if self.latent_channels == 32: # HunyuanVideo1.5 act_func = nn.LeakyReLU(0.2, inplace=True) else: # HunyuanVideo, Wan 2.1 @@ -131,41 +135,52 @@ def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, self.encoder = nn.Sequential( conv(self.image_channels*self.patch_size**2, 64), act_func, - TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), - TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), - TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), conv(64, self.latent_channels), ) n_f = [256, 128, 64, 64] - self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( Clamp(), conv(self.latent_channels, n_f[0]), act_func, - MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), - MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), - MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), + MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False), act_func, conv(n_f[3], self.image_channels*self.patch_size**2), ) - @property - def show_progress_bar(self): - return self._show_progress_bar - @show_progress_bar.setter - def show_progress_bar(self, value): - self._show_progress_bar = value + self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool)) + self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow)) + self.frames_to_trim = self.t_upscale - 1 + self._show_progress_bar = show_progress_bar + + @property + def show_progress_bar(self): + return self._show_progress_bar + + @show_progress_bar.setter + def show_progress_bar(self, value): + self._show_progress_bar = value def encode(self, x, **kwargs): + x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] if self.patch_size > 1: + B, T, C, H, W = x.shape + x = x.reshape(B * T, C, H, W) x = F.pixel_unshuffle(x, self.patch_size) - x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] - if x.shape[1] % 4 != 0: - # pad at end to multiple of 4 - n_pad = 4 - x.shape[1] % 4 + x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size) + if x.shape[1] % self.t_downscale != 0: + # pad at end to multiple of t_downscale + n_pad = self.t_downscale - x.shape[1] % self.t_downscale padding = x[:, -1:].repeat_interleave(n_pad, dim=1) x = torch.cat([x, padding], 1) x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) return self.process_out(x) def decode(self, x, **kwargs): + x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W] + x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W] x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) if self.patch_size > 1: diff --git a/latent_preview.py b/latent_preview.py index d52e3f7a1fd0..a9d777661a83 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -11,7 +11,7 @@ default_preview_method = args.preview_method MAX_PREVIEW_RESOLUTION = args.preview_size -VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] +VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] def preview_to_image(latent_image, do_scale=True): if do_scale: diff --git a/nodes.py b/nodes.py index 67b61dcfeeb2..8864fda605cb 100644 --- a/nodes.py +++ b/nodes.py @@ -707,7 +707,7 @@ def load_lora_model_only(self, model, lora_name, strength_model): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) class VAELoader: - video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] + video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @staticmethod def vae_list(s): From 245f6139b65899112d11ff294d36a820f2d69496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 22 Jan 2026 06:05:06 +0200 Subject: [PATCH 5/6] More targeted embedding_connector loading for LTX2 text encoder (#11992) Reduces errors --- comfy/text_encoders/lt.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index c33c77db7ff4..e491619640e3 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -118,9 +118,18 @@ def load_sd(self, sd): sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) if len(sdo) == 0: sdo = sd - missing, unexpected = self.load_state_dict(sdo, strict=False) - missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model - return (missing, unexpected) + + missing_all = [] + unexpected_all = [] + + for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: + component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} + if component_sd: + missing, unexpected = component.load_state_dict(component_sd, strict=False) + missing_all.extend([f"{prefix}{k}" for k in missing]) + unexpected_all.extend([f"{prefix}{k}" for k in unexpected]) + + return (missing_all, unexpected_all) def memory_estimation_function(self, token_weight_pairs, device=None): constant = 6.0 From 16b9aabd52c3b81b365fbf562bbcc4528111ef6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 22 Jan 2026 06:09:48 +0200 Subject: [PATCH 6/6] Support Multi/InfiniteTalk (#10179) * re-init * Update model_multitalk.py * whitespace... * Update model_multitalk.py * remove print * this is redundant * remove import * Restore preview functionality * Move block_idx to transformer_options * Remove LoopingSamplerCustomAdvanced * Remove looping functionality, keep extension functionality * Update model_multitalk.py * Handle ref_attn_mask with separate patch to avoid having to always return q and k from self_attn * Chunk attention map calculation for multiple speakers to reduce peak VRAM usage * Update model_multitalk.py * Add ModelPatch type back * Fix for latest upstream * Use DynamicCombo for cleaner node Basically just so that single_speaker mode hides mask inputs and 2nd audio input * Update nodes_wan.py --- comfy/ldm/wan/model.py | 17 +- comfy/ldm/wan/model_multitalk.py | 500 ++++++++++++++++++++++++++++++ comfy_api/latest/_io.py | 3 +- comfy_extras/nodes_model_patch.py | 41 +++ comfy_extras/nodes_wan.py | 169 +++++++++- 5 files changed, 727 insertions(+), 3 deletions(-) create mode 100644 comfy/ldm/wan/model_multitalk.py diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 4216ce83150d..ea123acb402a 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -62,6 +62,8 @@ def forward(self, x, freqs, transformer_options={}): x(Tensor): Shape [B, L, num_heads, C / num_heads] freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ + patches = transformer_options.get("patches", {}) + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim def qkv_fn_q(x): @@ -86,6 +88,10 @@ def qkv_fn_k(x): transformer_options=transformer_options, ) + if "attn1_patch" in patches: + for p in patches["attn1_patch"]: + x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options}) + x = self.o(x) return x @@ -225,6 +231,8 @@ def forward( """ # assert e.dtype == torch.float32 + patches = transformer_options.get("patches", {}) + if e.ndim < 4: e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) else: @@ -242,6 +250,11 @@ def forward( # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + + if "attn2_patch" in patches: + for p in patches["attn2_patch"]: + x = p({"x": x, "transformer_options": transformer_options}) + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) x = torch.addcmul(x, y, repeat_e(e[5], x)) return x @@ -488,7 +501,7 @@ def __init__(self, self.blocks = nn.ModuleList([ wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) - for _ in range(num_layers) + for i in range(num_layers) ]) # head @@ -541,6 +554,7 @@ def forward_orig( # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) # time embeddings @@ -738,6 +752,7 @@ def forward_orig( # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) # time embeddings diff --git a/comfy/ldm/wan/model_multitalk.py b/comfy/ldm/wan/model_multitalk.py new file mode 100644 index 000000000000..c9dd98c4ddc7 --- /dev/null +++ b/comfy/ldm/wan/model_multitalk.py @@ -0,0 +1,500 @@ +import torch +from einops import rearrange, repeat +import comfy +from comfy.ldm.modules.attention import optimized_attention + + +def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8): + scale = 1.0 / visual_q.shape[-1] ** 0.5 + visual_q = visual_q.transpose(1, 2) * scale + + B, H, x_seqlens, K = visual_q.shape + + x_ref_attn_maps = [] + for class_idx, ref_target_mask in enumerate(ref_target_masks): + ref_target_mask = ref_target_mask.view(1, 1, 1, -1) + + x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype) + chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens) + + for i in range(0, x_seqlens, chunk_size): + end_i = min(i + chunk_size, x_seqlens) + + attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens + + # Apply softmax + attn_max = attn_chunk.max(dim=-1, keepdim=True).values + attn_chunk = (attn_chunk - attn_max).exp() + attn_sum = attn_chunk.sum(dim=-1, keepdim=True) + attn_chunk = attn_chunk / (attn_sum + 1e-8) + + # Apply mask and sum + masked_attn = attn_chunk * ref_target_mask + x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8) + + del attn_chunk, masked_attn + + # Average across heads + x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens + x_ref_attn_maps.append(x_ref_attnmap) + + del visual_q, ref_k + + return torch.cat(x_ref_attn_maps, dim=0) + +def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2): + """Args: + query (torch.tensor): B M H K + key (torch.tensor): B M H K + shape (tuple): (N_t, N_h, N_w) + ref_target_masks: [B, N_h * N_w] + """ + + N_t, N_h, N_w = shape + + x_seqlens = N_h * N_w + ref_k = ref_k[:, :x_seqlens] + _, seq_lens, heads, _ = visual_q.shape + class_num, _ = ref_target_masks.shape + x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q) + + split_chunk = heads // split_num + + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map( + visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], + ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], + ref_target_masks + ) + x_ref_attn_maps += x_ref_attn_maps_perhead + + return x_ref_attn_maps / split_num + + +def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): + source_min, source_max = source_range + new_min, new_max = target_range + normalized = (column - source_min) / (source_max - source_min + epsilon) + scaled = normalized * (new_max - new_min) + new_min + return scaled + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +def get_audio_embeds(encoded_audio, audio_start, audio_end): + audio_embs = [] + human_num = len(encoded_audio) + audio_frames = encoded_audio[0].shape[0] + + indices = (torch.arange(4 + 1) - 2) * 1 + + for human_idx in range(human_num): + if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence + pad_len = audio_end - audio_frames + pad_shape = list(encoded_audio[human_idx].shape) + pad_shape[0] = pad_len + pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1))) + encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0) + else: + encoded_audio_in = encoded_audio[human_idx] + center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1) + audio_emb = encoded_audio_in[center_indices].unsqueeze(0) + audio_embs.append(audio_emb) + + return torch.cat(audio_embs, dim=0) + + +def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end): + audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end) + + first_frame_audio_emb_s = audio_embs[:, :1, ...] + latter_frame_audio_emb = audio_embs[:, 1:, ...] + latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4) + + middle_index = audio_proj.seq_len // 2 + + latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...] + latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...] + latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...] + latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2) + + audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s) + audio_emb = torch.cat(audio_emb.split(1), dim=2) + + return audio_emb + + +class RotaryPositionalEmbedding1D(torch.nn.Module): + def __init__(self, + head_dim, + ): + super().__init__() + self.head_dim = head_dim + self.base = 10000 + + def precompute_freqs_cis_1d(self, pos_indices): + freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) + freqs = freqs.to(pos_indices.device) + freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + return freqs + + def forward(self, x, pos_indices): + freqs_cis = self.precompute_freqs_cis_1d(pos_indices) + + x_ = x.float() + + freqs_cis = freqs_cis.float().to(x.device) + cos, sin = freqs_cis.cos(), freqs_cis.sin() + cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + x_ = (x_ * cos) + (rotate_half(x_) * sin) + + return x_.type_as(x) + +class SingleStreamAttention(torch.nn.Module): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + device=None, dtype=None, operations=None + ) -> None: + super().__init__() + self.dim = dim + self.encoder_hidden_states_dim = encoder_hidden_states_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor: + N_t, N_h, N_w = shape + + expected_tokens = N_t * N_h * N_w + actual_tokens = x.shape[1] + x_extra = None + + if actual_tokens != expected_tokens: + x_extra = x[:, -N_h * N_w:, :] + x = x[:, :-N_h * N_w, :] + N_t = N_t - 1 + + B = x.shape[0] + S = N_h * N_w + x = x.view(B * N_t, S, self.dim) + + # get q for hidden_state + q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim) + + # get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim) + kv = self.kv_linear(encoder_hidden_states) + encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2) + + #print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128]) + x = optimized_attention( + q.transpose(1, 2), + encoder_k.transpose(1, 2), + encoder_v.transpose(1, 2), + heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2) + + # linear transform + x = self.proj(x.reshape(B * N_t, S, self.dim)) + x = x.view(B, N_t * S, self.dim) + + if x_extra is not None: + x = torch.cat([x, torch.zeros_like(x_extra)], dim=1) + + return x + +class SingleStreamMultiAttention(SingleStreamAttention): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + class_range: int = 24, + class_interval: int = 4, + device=None, dtype=None, operations=None + ) -> None: + super().__init__( + dim=dim, + encoder_hidden_states_dim=encoder_hidden_states_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + device=device, + dtype=dtype, + operations=operations + ) + + # Rotary-embedding layout parameters + self.class_interval = class_interval + self.class_range = class_range + self.max_humans = self.class_range // self.class_interval + + # Constant bucket used for background tokens + self.rope_bak = int(self.class_range // 2) + + self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim) + + def forward( + self, + x: torch.Tensor, + encoder_hidden_states: torch.Tensor, + shape=None, + x_ref_attn_map=None + ) -> torch.Tensor: + encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device) + human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1 + # Single-speaker fall-through + if human_num <= 1: + return super().forward(x, encoder_hidden_states, shape) + + N_t, N_h, N_w = shape + + x_extra = None + if x.shape[0] * N_t != encoder_hidden_states.shape[0]: + x_extra = x[:, -N_h * N_w:, :] + x = x[:, :-N_h * N_w, :] + N_t = N_t - 1 + x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) + + # Query projection + B, N, C = x.shape + q = self.q_linear(x) + q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + + # Use `class_range` logic for 2 speakers + rope_h1 = (0, self.class_interval) + rope_h2 = (self.class_range - self.class_interval, self.class_range) + rope_bak = int(self.class_range // 2) + + # Normalize and scale attention maps for each speaker + max_values = x_ref_attn_map.max(1).values[:, None, None] + min_values = x_ref_attn_map.min(1).values[:, None, None] + max_min_values = torch.cat([max_values, min_values], dim=2) + + human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min() + human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min() + + human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1) + human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2) + back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device) + + # Token-wise speaker dominance + max_indices = x_ref_attn_map.argmax(dim=0) + normalized_map = torch.stack([human1, human2, back], dim=1) + normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices] + + # Apply rotary to Q + q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + q = self.rope_1d(q, normalized_pos) + q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + # Keys / Values + _, N_a, _ = encoder_hidden_states.shape + encoder_kv = self.kv_linear(encoder_hidden_states) + encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + encoder_k, encoder_v = encoder_kv.unbind(0) + + # Rotary for keys – assign centre of each speaker bucket to its context tokens + per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device) + per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2 + per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2 + encoder_pos = torch.cat([per_frame] * N_t, dim=0) + + encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + encoder_k = self.rope_1d(encoder_k, encoder_pos) + encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + # Final attention + q = rearrange(q, "B H M K -> B M H K") + encoder_k = rearrange(encoder_k, "B H M K -> B M H K") + encoder_v = rearrange(encoder_v, "B H M K -> B M H K") + + x = optimized_attention( + q.transpose(1, 2), + encoder_k.transpose(1, 2), + encoder_v.transpose(1, 2), + heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2) + + # Linear projection + x = x.reshape(B, N, C) + x = self.proj(x) + + # Restore original layout + x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t) + if x_extra is not None: + x = torch.cat([x, torch.zeros_like(x_extra)], dim=1) + + return x + + +class MultiTalkAudioProjModel(torch.nn.Module): + def __init__( + self, + seq_len: int = 5, + seq_len_vf: int = 12, + blocks: int = 12, + channels: int = 768, + intermediate_dim: int = 512, + out_dim: int = 768, + context_tokens: int = 32, + device=None, dtype=None, operations=None + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels + self.input_dim_vf = seq_len_vf * blocks * channels + self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.out_dim = out_dim + + # define multiple linear layers + self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype) + self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype) + self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype) + self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype) + + def forward(self, audio_embeds, audio_embeds_vf): + video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1] + B, _, _, S, C = audio_embeds.shape + + # process audio of first frame + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + # process audio of latter frame + audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c") + batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape + audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf) + + # first projection + audio_embeds = torch.relu(self.proj1(audio_embeds)) + audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf)) + audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B) + audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B) + audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1) + batch_size_c, N_t, C_a = audio_embeds_c.shape + audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a) + + # second projection + audio_embeds_c = torch.relu(self.proj2(audio_embeds_c)) + + context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim) + + # normalization and reshape + context_tokens = self.norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens + + +class WanMultiTalkAttentionBlock(torch.nn.Module): + def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None): + super().__init__() + self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations) + self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True) + + +class MultiTalkGetAttnMapPatch: + def __init__(self, ref_target_masks=None): + self.ref_target_masks = ref_target_masks + + def __call__(self, kwargs): + transformer_options = kwargs.get("transformer_options", {}) + x = kwargs["x"] + + if self.ref_target_masks is not None: + x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device)) + transformer_options["x_ref_attn_map"] = x_ref_attn_map + return x + + +class MultiTalkCrossAttnPatch: + def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None): + self.model_patch = model_patch + self.audio_scale = audio_scale + self.ref_target_masks = ref_target_masks + + def __call__(self, kwargs): + transformer_options = kwargs.get("transformer_options", {}) + block_idx = transformer_options.get("block_index", None) + x = kwargs["x"] + if block_idx is None: + return torch.zeros_like(x) + + audio_embeds = transformer_options.get("audio_embeds") + x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None) + + norm_x = self.model_patch.model.blocks[block_idx].norm_x(x) + x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn( + norm_x, audio_embeds.to(x.dtype), + shape=transformer_options["grid_sizes"], + x_ref_attn_map=x_ref_attn_map + ) + x = x + x_audio * self.audio_scale + return x + + def models(self): + return [self.model_patch] + +class MultiTalkApplyModelWrapper: + def __init__(self, init_latents): + self.init_latents = init_latents + + def __call__(self, executor, x, *args, **kwargs): + x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x) + samples = executor(x, *args, **kwargs) + return samples + + +class InfiniteTalkOuterSampleWrapper: + def __init__(self, motion_frames_latent, model_patch, is_extend=False): + self.motion_frames_latent = motion_frames_latent + self.model_patch = model_patch + self.is_extend = is_extend + + def __call__(self, executor, *args, **kwargs): + model_patcher = executor.class_obj.model_patcher + model_options = executor.class_obj.model_options + process_latent_in = model_patcher.model.process_latent_in + + # for InfiniteTalk, model input first latent(s) need to always be replaced on every step + if self.motion_frames_latent is not None: + wrappers = model_options["transformer_options"]["wrappers"] + w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {}) + w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))] + + # run the sampling process + result = executor(*args, **kwargs) + + # insert motion frames before decoding + if self.is_extend: + overlap = self.motion_frames_latent.shape[2] + result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2) + + return result + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + if self.motion_frames_latent is not None: + self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype) + return self diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index a60020ca8257..2ec8d6e4be76 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -754,7 +754,7 @@ class AnyType(ComfyTypeIO): Type = Any @comfytype(io_type="MODEL_PATCH") -class MODEL_PATCH(ComfyTypeIO): +class ModelPatch(ComfyTypeIO): Type = Any @comfytype(io_type="AUDIO_ENCODER") @@ -2038,6 +2038,7 @@ def as_dict(self) -> dict: "ControlNet", "Vae", "Model", + "ModelPatch", "ClipVision", "ClipVisionOutput", "AudioEncoder", diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index f66d28fc9214..82c4754a3f79 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -7,6 +7,7 @@ import comfy.ldm.common_dit import comfy.latent_formats import comfy.ldm.lumina.controlnet +from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel class BlockWiseControlBlock(torch.nn.Module): @@ -257,6 +258,14 @@ def load_model_patch(self, name): if torch.count_nonzero(ref_weight) == 0: config['broken'] = True model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) + elif "audio_proj.proj1.weight" in sd: + model = MultiTalkModelPatch( + audio_window=5, context_tokens=32, vae_scale=4, + in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0], + intermediate_dim=sd["audio_proj.proj1.weight"].shape[0], + out_dim=sd["audio_proj.norm.weight"].shape[0], + device=comfy.model_management.unet_offload_device(), + operations=comfy.ops.manual_cast) model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) @@ -524,6 +533,38 @@ def apply_patch(self, model, model_patch, clip_vision_output): return (model_patched,) +class MultiTalkModelPatch(torch.nn.Module): + def __init__( + self, + audio_window: int = 5, + intermediate_dim: int = 512, + in_dim: int = 5120, + out_dim: int = 768, + context_tokens: int = 32, + vae_scale: int = 4, + num_layers: int = 40, + + device=None, dtype=None, operations=None + ): + super().__init__() + self.audio_proj = MultiTalkAudioProjModel( + seq_len=audio_window, + seq_len_vf=audio_window+vae_scale-1, + intermediate_dim=intermediate_dim, + out_dim=out_dim, + context_tokens=context_tokens, + device=device, + dtype=dtype, + operations=operations + ) + self.blocks = torch.nn.ModuleList( + [ + WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ] + ) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index d32aad98e024..90deb0077cfc 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -8,9 +8,10 @@ import comfy.clip_vision import json import numpy as np -from typing import Tuple +from typing import Tuple, TypedDict from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import logging class WanImageToVideo(io.ComfyNode): @classmethod @@ -1288,6 +1289,171 @@ def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io return io.NodeOutput(out_latent) +from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features +class WanInfiniteTalkToVideo(io.ComfyNode): + class DCValues(TypedDict): + mode: str + audio_encoder_output_2: io.AudioEncoderOutput.Type + mask: io.Mask.Type + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanInfiniteTalkToVideo", + category="conditioning/video_models", + inputs=[ + io.DynamicCombo.Input("mode", options=[ + io.DynamicCombo.Option("single_speaker", []), + io.DynamicCombo.Option("two_speakers", [ + io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True), + io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."), + io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."), + ]), + ]), + io.Model.Input("model"), + io.ModelPatch.Input("model_patch"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.AudioEncoderOutput.Input("audio_encoder_output_1"), + io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."), + io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01), + io.Image.Input("previous_frames", optional=True), + ], + outputs=[ + io.Model.Output(display_name="model"), + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + io.Int.Output(display_name="trim_image"), + ], + ) + + @classmethod + def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count, + start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput: + + if previous_frames is not None and previous_frames.shape[0] < motion_frame_count: + raise ValueError("Not enough previous frames provided.") + + if mode["mode"] == "two_speakers": + audio_encoder_output_2 = mode["audio_encoder_output_2"] + mask_1 = mode["mask_1"] + mask_2 = mode["mask_2"] + + if audio_encoder_output_2 is not None: + if mask_1 is None or mask_2 is None: + raise ValueError("Masks must be provided if two audio encoder outputs are used.") + + ref_masks = None + if mask_1 is not None and mask_2 is not None: + if audio_encoder_output_2 is None: + raise ValueError("Second audio encoder output must be provided if two masks are used.") + ref_masks = torch.cat([mask_1, mask_2]) + + latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5 + image[:start_image.shape[0]] = start_image + + concat_latent_image = vae.encode(image[:, :, :, :3]) + concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + model_patched = model.clone() + + encoded_audio_list = [] + seq_lengths = [] + + for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]: + if audio_encoder_output is None: + continue + all_layers = audio_encoder_output["encoded_audio_all_layers"] + encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512] + encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512] + encoded_audio_list.append(encoded_audio) + seq_lengths.append(encoded_audio.shape[0]) + + # Pad / combine depending on multi_audio_type + multi_audio_type = "add" + if len(encoded_audio_list) > 1: + if multi_audio_type == "para": + max_len = max(seq_lengths) + padded = [] + for emb in encoded_audio_list: + if emb.shape[0] < max_len: + pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype) + emb = torch.cat([emb, pad], dim=0) + padded.append(emb) + encoded_audio_list = padded + elif multi_audio_type == "add": + total_len = sum(seq_lengths) + full_list = [] + offset = 0 + for emb, seq_len in zip(encoded_audio_list, seq_lengths): + full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype) + full[offset:offset+seq_len] = emb + full_list.append(full) + offset += seq_len + encoded_audio_list = full_list + + token_ref_target_masks = None + if ref_masks is not None: + token_ref_target_masks = torch.nn.functional.interpolate( + ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0] + token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1) + + # when extending from previous frames + if previous_frames is not None: + motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + frame_offset = previous_frames.shape[0] - motion_frame_count + + audio_start = frame_offset + audio_end = audio_start + length + logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}") + + motion_frames_latent = vae.encode(motion_frames[:, :, :, :3]) + trim_image = motion_frame_count + else: + audio_start = trim_image = 0 + audio_end = length + motion_frames_latent = concat_latent_image[:, :, :1] + + audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype()) + model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed + + # add outer sample wrapper + model_patched.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, + "infinite_talk_outer_sample", + InfiniteTalkOuterSampleWrapper( + motion_frames_latent, + model_patch, + is_extend=previous_frames is not None, + )) + # add cross-attention patch + model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch") + if token_ref_target_masks is not None: + model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch") + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) + + class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1307,6 +1473,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: WanHuMoImageToVideo, WanAnimateToVideo, Wan22ImageToVideoLatent, + WanInfiniteTalkToVideo, ] async def comfy_entrypoint() -> WanExtension: