91 changes: 74 additions & 17 deletions comfy/text_encoders/llama.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
from typing import Optional, Any, Tuple
import math

from comfy.ldm.modules.attention import optimized_attention_for_device
@@ -32,6 +32,7 @@ class Llama2Config:
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False

@dataclass
class Mistral3Small24BConfig:
@@ -54,6 +55,7 @@ class Mistral3Small24BConfig:
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False

@dataclass
class Qwen25_3BConfig:
@@ -76,6 +78,7 @@ class Qwen25_3BConfig:
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False

@dataclass
class Qwen3_06BConfig:
@@ -98,6 +101,7 @@ class Qwen3_06BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False

@dataclass
class Qwen3_4BConfig:
@@ -120,6 +124,7 @@ class Qwen3_4BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False

@dataclass
class Qwen3_8BConfig:
@@ -142,6 +147,7 @@ class Qwen3_8BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False

@dataclass
class Ovis25_2BConfig:
@@ -164,6 +170,7 @@ class Ovis25_2BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False

@dataclass
class Qwen25_7BVLI_Config:
@@ -186,6 +193,7 @@ class Qwen25_7BVLI_Config:
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False

@dataclass
class Gemma2_2B_Config:
@@ -209,6 +217,7 @@ class Gemma2_2B_Config:
sliding_attention = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False

@dataclass
class Gemma3_4B_Config:
@@ -232,6 +241,7 @@ class Gemma3_4B_Config:
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False

@dataclass
class Gemma3_12B_Config:
@@ -255,6 +265,7 @@ class Gemma3_12B_Config:
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
mm_tokens_per_image = 256

@@ -356,6 +367,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states)
@@ -373,11 +385,30 @@ def forward(

xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)

present_key_value = None
if past_key_value is not None:
index = 0
num_tokens = xk.shape[2]
if len(past_key_value) > 0:
past_key, past_value, index = past_key_value
if past_key.shape[2] >= (index + num_tokens):
past_key[:, :, index:index + xk.shape[2]] = xk
past_value[:, :, index:index + xv.shape[2]] = xv
xk = past_key[:, :, :index + xk.shape[2]]
xv = past_value[:, :, :index + xv.shape[2]]
present_key_value = (past_key, past_value, index + num_tokens)
else:
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
present_key_value = (xk, xv, index + num_tokens)
else:
present_key_value = (xk, xv, index + num_tokens)

xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
return self.o_proj(output)
return self.o_proj(output), present_key_value

class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
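For reference, here is a minimal standalone sketch of the cache update this attention hunk implements, using the same `(past_key, past_value, index)` tuple layout: when a preallocated buffer has room, new keys/values are written in place at the current index; otherwise the cache grows by concatenation along the sequence axis.

```python
import torch

def update_kv_cache(xk, xv, past_key_value):
    """Sketch of the cache update above.

    xk, xv: [batch, kv_heads, new_tokens, head_dim]
    past_key_value: empty on the first step, else (past_key, past_value, index).
    Returns (keys, values, new_cache) where keys/values feed attention.
    """
    num_tokens = xk.shape[2]
    if len(past_key_value) == 0:
        # First step: the new keys/values become the whole cache, write index = num_tokens.
        return xk, xv, (xk, xv, num_tokens)
    past_key, past_value, index = past_key_value
    if past_key.shape[2] >= index + num_tokens:
        # Preallocated buffer with room left: write in place, slice out the valid prefix.
        past_key[:, :, index:index + num_tokens] = xk
        past_value[:, :, index:index + num_tokens] = xv
        keys = past_key[:, :, :index + num_tokens]
        values = past_value[:, :, :index + num_tokens]
        return keys, values, (past_key, past_value, index + num_tokens)
    # No room (or no buffer): grow by concatenating along the sequence axis.
    keys = torch.cat((past_key[:, :, :index], xk), dim=2)
    values = torch.cat((past_value[:, :, :index], xv), dim=2)
    return keys, values, (keys, values, index + num_tokens)
```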
@@ -408,15 +439,17 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
# Self Attention
residual = x
x = self.input_layernorm(x)
x = self.self_attn(
x, present_key_value = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_key_value,
)
x = residual + x

@@ -426,7 +459,7 @@ def forward(
x = self.mlp(x)
x = residual + x

return x
return x, present_key_value

class TransformerBlockGemma2(nn.Module):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
@@ -451,6 +484,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
@@ -468,11 +502,12 @@ def forward(
# Self Attention
residual = x
x = self.input_layernorm(x)
x = self.self_attn(
x, present_key_value = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_key_value,
)

x = self.post_attention_layernorm(x)
@@ -485,7 +520,7 @@ def forward(
x = self.post_feedforward_layernorm(x)
x = residual + x

return x
return x, present_key_value

class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
@@ -516,9 +551,10 @@ def __init__(self, config, device=None, dtype=None, ops=None):
else:
self.norm = None

# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
if config.lm_head:
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
if embeds is not None:
x = embeds
else:
@@ -527,8 +563,13 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
if self.normalize_in:
x *= self.config.hidden_size ** 0.5

seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:
past_len = past_key_values[0][2]

if position_ids is None:
position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)

freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
@@ -539,14 +580,16 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed

mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
if seq_len > 1:
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask

optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

intermediate = None
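To illustrate the shapes this hunk sets up for incremental decoding: positions for the new tokens continue after the cached prefix, and the extra causal mask is only built when more than one new token is processed at once (a single-token decode step has nothing future-facing to hide). A small sketch with hypothetical sizes:

```python
import torch

# Hypothetical sizes: 12 tokens already cached, 1 new token this step.
past_len, seq_len = 12, 1
device, dtype = "cpu", torch.float32

# Positions for the new tokens resume after the cached prefix.
position_ids = torch.arange(past_len, past_len + seq_len, device=device).unsqueeze(0)

# Causal mask is only needed when several new tokens attend to each other;
# when built, it spans the full cached-plus-new length.
mask = None
if seq_len > 1:
    mask = torch.empty(past_len + seq_len, past_len + seq_len,
                       dtype=dtype, device=device).fill_(float("-inf")).triu_(1)

print(position_ids)  # tensor([[12]])
print(mask)          # None on a one-token decode step
```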
@@ -562,16 +605,27 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output

next_key_values = []
for i, layer in enumerate(self.layers):
if all_intermediate is not None:
if only_layers is None or (i in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
x = layer(

past_kv = None
if past_key_values is not None:
past_kv = past_key_values[i] if len(past_key_values) > 0 else []

x, current_kv = layer(
x=x,
attention_mask=mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_kv,
)

if current_kv is not None:
next_key_values.append(current_kv)

if i == intermediate_output:
intermediate = x.clone()

@@ -588,7 +642,10 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
intermediate = self.norm(intermediate)

return x, intermediate
if len(next_key_values) > 0:
return x, intermediate, next_key_values
else:
return x, intermediate


class Gemma3MultiModalProjector(torch.nn.Module):
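Taken together, the llama.py changes let the text encoder be driven as an incremental decoder: pass an empty `past_key_values` list on the first call, then feed back the returned cache and only the newest token on subsequent calls. A rough usage sketch, assuming `model` is a `Llama2_` built with `config.lm_head = True` (the loop and names are illustrative, not part of the diff):

```python
import torch

@torch.no_grad()
def greedy_decode(model, prompt_ids, max_new_tokens=32):
    """Illustrative only; assumes a Llama2_ instance with config.lm_head = True."""
    tokens = prompt_ids            # [batch, prompt_len] token ids
    past_key_values = []           # empty list -> every layer starts a fresh cache
    x = tokens
    for _ in range(max_new_tokens):
        # With caching active the patched forward returns (hidden, intermediate, cache).
        hidden, _, past_key_values = model(x, past_key_values=past_key_values)
        logits = model.lm_head(hidden[:, -1])   # lm_head only exists when config.lm_head is True
        next_token = logits.argmax(dim=-1, keepdim=True)
        tokens = torch.cat((tokens, next_token), dim=1)
        x = next_token             # feed only the new token; positions resume at past_len
    return tokens
```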
2 changes: 2 additions & 0 deletions comfy_api/latest/_io.py
@@ -1248,6 +1248,7 @@ class Hidden(str, Enum):
class NodeInfoV1:
input: dict=None
input_order: dict[str, list[str]]=None
is_input_list: bool=None
output: list[str]=None
output_is_list: list[bool]=None
output_name: list[str]=None
@@ -1474,6 +1475,7 @@ def get_v1_info(self, cls) -> NodeInfoV1:
info = NodeInfoV1(
input=input,
input_order={key: list(value.keys()) for (key, value) in input.items()},
is_input_list=self.is_input_list,
output=output,
output_is_list=output_is_list,
output_name=output_name,
1 change: 1 addition & 0 deletions server.py
@@ -656,6 +656,7 @@ def node_info(node_class):
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
info['is_input_list'] = getattr(obj_class, "INPUT_IS_LIST", False)
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
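For context, `INPUT_IS_LIST` is the existing node class attribute that tells the executor to hand every input to the node as a list; the new `is_input_list` field simply surfaces that flag through node info. A hypothetical node illustrating it (not part of the diff):

```python
class ConcatStrings:
    # With INPUT_IS_LIST = True the executor passes each input as a list.
    INPUT_IS_LIST = True
    RETURN_TYPES = ("STRING",)
    FUNCTION = "concat"
    CATEGORY = "utils"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"text": ("STRING", {})}}

    def concat(self, text):
        # `text` arrives as a list of strings because INPUT_IS_LIST is True.
        return ("".join(text),)

# node_info() for this class would now report: "is_input_list": True
```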