
[FEA] FP8 grouped gemm kernel without TMA #1483

Open
masahi opened this issue Apr 15, 2024 · 7 comments
Labels
feature request, inactive-30d

Comments

@masahi
Contributor

masahi commented Apr 15, 2024

We want to try a tile size M smaller than 128, which fits our workload better. I'm assuming that the requirement that tile size M be a multiple of 128 comes from TMA, but for the small problem sizes we encounter in practice, TMA might be overkill.

According to @hwu36, this is not currently supported.

masahi added the ? - Needs Triage and feature request labels Apr 15, 2024
@hwu36
Collaborator

hwu36 commented Apr 15, 2024

@ANIKET-SHIVAM

@IonThruster
Collaborator

The BLOCK_M >= 128 requirement likely comes from the fact that only the cooperative kernel schedule is supported today in CUTLASS 3.5. If support for the other kernel schedules is added (tma_warpspecialized or tma_warpspecialized_pingpong), we can go as low as 64, since the FP8 MMA instruction shape is 64xNx32.

@masahi: could you share the problem shape for your specific grouped GEMM so we can better recommend next steps?
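For illustration, here is a minimal sketch of what a 64-row CTA tile with one of those schedules looks like through the 3.x CollectiveBuilder. This is a plain, non-grouped FP8 mainloop configuration only; the grouped (ptr-array) variant of this schedule is exactly the missing piece, so treat the schedule tag and tile shape below as the assumption under discussion, not a working grouped-GEMM setup.

```cpp
// Sketch only: a non-grouped SM90 FP8 mainloop built with a 64-row CTA tile and the
// pingpong schedule. The grouped (ptr-array) counterpart of this schedule is what
// this issue asks about.
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"

using namespace cute;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::float_e4m3_t, cutlass::layout::RowMajor,    16,  // A: FP8 e4m3, K-major
    cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 16,  // B: FP8 e4m3, K-major
    float,                                                    // accumulator
    Shape<_64, _128, _128>,                                   // CTA tile: M = 64 instead of 128
    Shape<_1, _1, _1>,                                        // cluster shape
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::KernelTmaWarpSpecializedPingpong           // cooperative requires BLOCK_M >= 128
  >::CollectiveOp;
```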

@thakkarV
Collaborator

thakkarV commented Apr 15, 2024

Additionally, for really small values of M you are likely to be bandwidth-bound anyway, in which case you can likely get roofline performance from recompiling the CUTLASS 2.x Ampere kernels (with or without stream-K).

@masahi
Contributor Author

masahi commented Apr 15, 2024

Additionally, for really small values of M you are likely to be bandwidth-bound anyway, in which case you can likely get roofline performance from recompiling the CUTLASS 2.x Ampere kernels (with or without stream-K).

We did observe that the sm80 kernel is faster for small M. Are you suggesting that the 2.x Ampere kernels can be compiled with FP8?

@masahi
Contributor Author

masahi commented Apr 15, 2024

The BLOCK_M >= 128 requirement likely comes from the fact that only the cooperative kernel schedule is supported today in CUTLASS 3.5. If support for the other kernel schedules is added (tma_warpspecialized or tma_warpspecialized_pingpong), we can go as low as 64, since the FP8 MMA instruction shape is 64xNx32.

@masahi: could you share the problem shape for your specific grouped GEMM so we can better recommend next steps?

I'm working on LLM inference. The problem shape M is dynamic and depends on the number of requests in a batch (let's call it the "batch size"). The batch size can be as small as 20, for example, and those 20 requests are dynamically routed to 8 experts (or groups) in Mixtral. So even if all 20 requests are routed to a single expert, the problem size M for that expert is only 20 (and 0 for all the other experts).

A batch size of 20 is just an example; depending on the model size and the number of active users, it can range from a few dozen to a few hundred. Such problem shapes might look ridiculously small from your perspective, but we can't always keep the batch size large, since we also need to minimize latency. BLOCK_M = 64 is still too big given the description above, but it is certainly much better.
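To make that concrete, here is a small sketch (function and variable names are illustrative, not from any particular serving stack) of how the per-expert problem sizes for the grouped GEMM fall out of token routing:

```cpp
// Illustrative only: derive per-group (per-expert) problem sizes for a grouped GEMM
// from MoE token routing. With a batch of 20 tokens all routed to one expert, this
// yields one problem with M = 20 and seven with M = 0.
#include <array>
#include <cstdint>
#include <vector>

constexpr int kNumExperts = 8;  // e.g. Mixtral's 8 experts

// problem_sizes[e] = {M_e, N, K}; M_e is dynamic, N and K come from the expert weights.
std::vector<std::array<int64_t, 3>>
make_group_problem_sizes(std::vector<int> const& expert_id_per_token, int64_t N, int64_t K) {
  std::array<int64_t, kNumExperts> tokens_per_expert{};  // zero-initialized counts
  for (int e : expert_id_per_token) {
    tokens_per_expert[e]++;
  }
  std::vector<std::array<int64_t, 3>> problem_sizes;
  problem_sizes.reserve(kNumExperts);
  for (int e = 0; e < kNumExperts; ++e) {
    problem_sizes.push_back({tokens_per_expert[e], N, K});
  }
  return problem_sizes;
}
```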

@ANIKET-SHIVAM
Collaborator

Hi @masahi,
you can try to extend the CpAsyncWarpSpecialized Hopper kernels (such as sm90_gemm_warpspecialized_pingpong.hpp) to FP8 grouped GEMM by using components from the TMA Hopper grouped GEMM kernel.

I believe the Group Tile Scheduler should be easy to plug in: sm90_tile_scheduler_group.hpp

You can then change the pointers to A and B based on the group index: https://github.com/NVIDIA/cutlass/blob/7d49e6c7e2f8896c47f586706e67e1fb215529dc/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp#L344C12-L344C19

And the epilogue from the TMA Hopper example (cutlass::epilogue::PtrArrayNoSmemWarpSpecialized) should work as-is, too.
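A schematic sketch of the pointer-selection step described above: every name here (GroupParams, WorkTileInfo, select_group_operands) is a placeholder rather than the actual kernel's types, and in the real kernel this selection would happen where the linked line builds the global-memory tensors, before the cp.async producer warps copy tiles of A and B.

```cpp
// Schematic only -- not the actual sm90_gemm_warpspecialized_pingpong.hpp code.
// It shows the shape of the change: the group tile scheduler tells the CTA which
// GEMM in the group the current tile belongs to, and the kernel re-points A/B
// (plus their strides and problem shape) at that group's buffers.
#include <cstdint>
#include <tuple>

template <class ElementA, class ElementB, class StrideA, class StrideB, class ProblemShape>
struct GroupParams {                    // placeholder for grouped mainloop/kernel params
  ProblemShape const* problem_shapes;   // one (M, N, K) per group
  ElementA const* const* ptr_A;         // per-group A pointers
  ElementB const* const* ptr_B;         // per-group B pointers
  StrideA const* dA;                    // per-group strides
  StrideB const* dB;
};

struct WorkTileInfo {                   // placeholder for the scheduler's work descriptor
  int32_t m_idx, n_idx;                 // CTA tile coordinate within that group's GEMM
  int32_t group_idx;                    // which GEMM in the group this tile belongs to
};

// Pick the operands for the current tile's group. In the kernel, the gmem tensors
// mA/mB would be rebuilt from these before the producer warps issue cp.async copies.
template <class Params>
auto select_group_operands(Params const& params, WorkTileInfo const& work) {
  int32_t g = work.group_idx;
  return std::make_tuple(params.problem_shapes[g],
                         params.ptr_A[g], params.dA[g],
                         params.ptr_B[g], params.dB[g]);
}
```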


This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.
