add fa4 and refactor spas attn #962

Open
STwangyingrui wants to merge 3 commits into main from yr/add_fa4_and_refactor_spas_attn

Conversation

@STwangyingrui (Contributor)

No description provided.


@gemini-code-assist bot left a comment

Code Review

This pull request introduces several new sparse attention mechanisms and operators, including FlashAttention v4 and SageAttention v2/v3, along with corresponding configuration files and utility functions. Key changes include the addition of SparseFlashAttn4Weight, SparseSageAttn2Weight, and SparseSageAttn3Weight, as well as a new SpargeMaskGenerator and a comprehensive sparge_util.py containing Triton kernels for block map generation. Feedback focuses on improving error handling by raising exceptions for unsupported sparse modes instead of just logging them, and refining assertion messages for better clarity and professionalism.
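The mean-similarity block map at the heart of these operators can be illustrated with a small self-contained sketch. This is a toy NumPy stand-in for the PR's Triton-based `get_block_map_meansim`, not the actual kernel: the block pooling, dot-product scoring, and top-k selection shown here are illustrative assumptions about how such a map is built.

```python
import numpy as np

def toy_block_map_meansim(q, k, topk, BLKQ=64, BLKK=64):
    # q, k: (seq_len, head_dim). Pool each BLKQ/BLKK-sized block down to
    # its mean vector, score Q-blocks against K-blocks by dot product,
    # and keep the top-k K-blocks per Q-block as a boolean sparse map.
    nq = q.shape[0] // BLKQ
    nk = k.shape[0] // BLKK
    q_means = q[: nq * BLKQ].reshape(nq, BLKQ, -1).mean(axis=1)
    k_means = k[: nk * BLKK].reshape(nk, BLKK, -1).mean(axis=1)
    scores = q_means @ k_means.T                    # (nq, nk)
    keep = np.argsort(-scores, axis=1)[:, :topk]    # top-k K-block indices
    sparse_map = np.zeros((nq, nk), dtype=bool)
    np.put_along_axis(sparse_map, keep, True, axis=1)
    return sparse_map

rng = np.random.default_rng(0)
q = rng.standard_normal((256, 64)).astype(np.float32)
k = rng.standard_normal((256, 64)).astype(np.float32)
# Mirror the PR's K-smoothing: smooth_k = k - k.mean(dim=-2, keepdim=True)
smooth_k = k - k.mean(axis=0, keepdims=True)
m = toy_block_map_meansim(q, smooth_k, topk=2, BLKQ=64, BLKK=64)
print(m.shape)  # (4, 4): each Q-block attends to exactly topk=2 K-blocks
```

Each attention operator in the PR then consumes such a `sparse_map` to skip whole query-key block pairs.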

```python
    smooth_k = kt - kt.mean(dim=-2, keepdim=True)
    sparse_map = get_block_map_meansim(qt, smooth_k, cdfthreshd=None, topk=self.topk, return_lut=False, BLKQ=self.BLKQ, BLKK=self.BLKK)
else:
    logger.info(f"spas_flash_attn4 sparse_mode only support sla_mode and sparge_mode now.")
```

gemini-code-assist bot (severity: high)

Using `logger.info` for an unsupported `sparse_mode` might lead to silent failures or misconfigurations being overlooked. It would be more robust to either raise a `ValueError` or use `logger.error` to clearly indicate an invalid state.

Suggested change:

```diff
-    logger.info(f"spas_flash_attn4 sparse_mode only support sla_mode and sparge_mode now.")
+    raise ValueError(f"Unsupported sparse_mode: {self.sparse_mode}. spas_flash_attn4 sparse_mode only supports 'sla_mode' and 'sparge_mode'.")
```
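The fail-fast pattern the bot suggests can be factored into one shared guard so all three operators validate the mode identically. This is a hedged sketch: `SUPPORTED_SPARSE_MODES` and `check_sparse_mode` are hypothetical names, not identifiers from the PR.

```python
SUPPORTED_SPARSE_MODES = ("sla_mode", "sparge_mode")  # hypothetical constant

def check_sparse_mode(sparse_mode: str) -> None:
    # Raise instead of logging at INFO level, so an unsupported mode
    # cannot silently fall through to an undefined attention path.
    if sparse_mode not in SUPPORTED_SPARSE_MODES:
        raise ValueError(
            f"Unsupported sparse_mode: {sparse_mode!r}. "
            f"Supported modes are: {', '.join(SUPPORTED_SPARSE_MODES)}."
        )

check_sparse_mode("sla_mode")  # passes silently

try:
    check_sparse_mode("dense_mode")
except ValueError as exc:
    print(exc)
```

Calling the guard at the top of each operator's forward path also removes the dangling `else` branch entirely.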

```python
    smooth_k = k - k.mean(dim=-2, keepdim=True)
    sparse_map = get_block_map_meansim(q, smooth_k, cdfthreshd=None, topk=self.topk, return_lut=False, BLKQ=self.BLKQ, BLKK=self.BLKK)
else:
    logger.info(f"spas_sage_attn2 sparse_mode only support sla_mode and sparge_mode now.")
```

gemini-code-assist bot (severity: high)

Using `logger.info` for an unsupported `sparse_mode` might lead to silent failures or misconfigurations being overlooked. It would be more robust to either raise a `ValueError` or use `logger.error` to clearly indicate an invalid state.

Suggested change:

```diff
-    logger.info(f"spas_sage_attn2 sparse_mode only support sla_mode and sparge_mode now.")
+    raise ValueError(f"Unsupported sparse_mode: {self.sparse_mode}. spas_sage_attn2 sparse_mode only supports 'sla_mode' and 'sparge_mode'.")
```

```python
    smooth_k = k - k.mean(dim=-2, keepdim=True)
    sparse_map = get_block_map_meansim(q, smooth_k, cdfthreshd=None, topk=self.topk, return_lut=False, BLKQ=self.BLKQ, BLKK=self.BLKK)
else:
    logger.info(f"spas_sage_attn3 sparse_mode only support sla_mode and sparge_mode now.")
```

gemini-code-assist bot (severity: high)

Using `logger.info` for an unsupported `sparse_mode` might lead to silent failures or misconfigurations being overlooked. It would be more robust to either raise a `ValueError` or use `logger.error` to clearly indicate an invalid state.

Suggested change:

```diff
-    logger.info(f"spas_sage_attn3 sparse_mode only support sla_mode and sparge_mode now.")
+    raise ValueError(f"Unsupported sparse_mode: {self.sparse_mode}. spas_sage_attn3 sparse_mode only supports 'sla_mode' and 'sparge_mode'.")
```

```python
    q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
elif len(q.shape) == 4:
    bs = q.shape[0]
    assert bs == 1, "flash_attn4 doesn't support flash_attn_varlen_func now. Just use it for batchsize = 1 for sure."
```

gemini-code-assist bot (severity: medium)

The assertion message here is a bit informal. Consider making it more professional to clearly communicate the limitation to users or developers.

Suggested change:

```diff
-    assert bs == 1, "flash_attn4 doesn't support flash_attn_varlen_func now. Just use it for batchsize = 1 for sure."
+    assert bs == 1, "FlashAttention v4 currently only supports batch size of 1 for this function."
```
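Beyond rewording the message, the shape handling around this assertion can be made explicit in one helper that either adds the batch axis or rejects batched input with a real exception (asserts are stripped under `python -O`). This sketch uses NumPy arrays for illustration; `normalize_qkv` is a hypothetical name and no FlashAttention kernel is invoked.

```python
import numpy as np

def normalize_qkv(q, k, v):
    # Accept (seq, heads, dim) and add a leading batch axis, or accept
    # (batch, seq, heads, dim) with batch == 1; reject anything else
    # with an explicit ValueError rather than a bare assert.
    if q.ndim == 3:
        return q[None], k[None], v[None]
    if q.ndim == 4:
        if q.shape[0] != 1:
            raise ValueError(
                "flash_attn4 does not support varlen/batched input yet; "
                f"expected batch size 1, got {q.shape[0]}."
            )
        return q, k, v
    raise ValueError(f"Expected 3D or 4D q/k/v, got {q.ndim}D.")

q = np.zeros((128, 8, 64), dtype=np.float32)
q4, k4, v4 = normalize_qkv(q, q.copy(), q.copy())
print(q4.shape)  # (1, 128, 8, 64)
```

The same helper would serve both call sites the bot flagged, so the batch-size limitation is stated (and enforced) in exactly one place.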

```python
    q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
elif len(q.shape) == 4:
    bs = q.shape[0]
    assert bs == 1, "flash_attn4 doesn't support flash_attn_varlen_func now. Just use it for batchsize = 1 for sure."
```

gemini-code-assist bot (severity: medium)

The assertion message here is a bit informal. Consider making it more professional to clearly communicate the limitation to users or developers.

Suggested change:

```diff
-    assert bs == 1, "flash_attn4 doesn't support flash_attn_varlen_func now. Just use it for batchsize = 1 for sure."
+    assert bs == 1, "FlashAttention v4 currently only supports batch size of 1 for this function."
```
