Add torch compile for mixtral #30793

Draft · wants to merge 32 commits into base: main

Conversation

zhenglongjiepheonix
Contributor

@zhenglongjiepheonix commented May 14, 2024

This PR is a work in progress that adds torch.compile support for Mixtral. It currently also contains changes from #30642 because the two models share some common ground. There are several issues regarding Mixtral:

  1. We have to set the following flag to True in order to capture the full graph with MoE:
torch._dynamo.config.capture_dynamic_output_shape_ops = True

I believe this is unavoidable: MixtralSparseMoeBlock uses torch.where to extract the tokens each expert handles, and the number and indices of tokens routed to each expert are variable. Even if we force a static shape (i.e. zero out the tokens an expert does not handle), we add extra computation, because the zeroed-out values still take part in the matmuls and every expert ends up processing the full token set, which defeats the whole computation-saving point of MoE. A minimal sketch of the flag in use follows the test summary below.

  2. The logits tests on the main branch are currently failing on my dev machine:
=========================================== short test summary info ===========================================
FAILED tests/models/mixtral/test_modeling_mixtral.py::MixtralModelTest::test_custom_4d_attention_mask - AssertionError: assert False
FAILED tests/models/mixtral/test_modeling_mixtral.py::MixtralIntegrationTest::test_small_model_logits - AssertionError: Tensor-likes are not close!
FAILED tests/models/mixtral/test_modeling_mixtral.py::MixtralIntegrationTest::test_small_model_logits_batched - AssertionError: Tensor-likes are not close!
=========================== 3 failed, 112 passed, 35 skipped, 47 warnings in 34.78s ===========================
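
Going back to point 1, here is a minimal sketch of the flag in use (the model variable is assumed to be an already loaded Mixtral; the compile options are illustrative rather than the exact wiring in this PR):

import torch

# Without this flag, dynamo cannot capture the full graph: torch.where in the
# MoE block returns tensors whose sizes depend on the routing decisions, which
# are only known at runtime.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

# `model` is assumed to be an already-loaded MixtralForCausalLM instance.
compiled_forward = torch.compile(model.forward, fullgraph=True)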

@zhenglongjiepheonix changed the title from "Longjie/add torch compile for mixtral" to "Add torch compile for mixtral" on May 14, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zhenglongjiepheonix marked this pull request as draft on May 14, 2024 02:53
@amyeroberts
Collaborator

cc @ArthurZucker

Comment on lines 851 to 856
# the `top_x` tensor here. this will give `skipping cudagraphs due to index put with accumulate`
# in compile
# final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

# still suffers from `skipping cudagraphs due to ['incompatible ops']`
final_hidden_states[top_x] += current_hidden_states.to(hidden_states.dtype)
Contributor Author

I am somewhat stuck on this: it emits cudagraph-skipped warnings no matter which equivalent form I write, so for now it seems cudagraphs can only be applied partially because of this. I have tried the following forms:

  1. final_hidden_states.index_add_
     gives skipping cudagraphs due to index put with accumulate
  2. final_hidden_states[top_x] += ...
     gives skipping cudagraphs due to ['incompatible ops']
  3. final_hidden_states.scatter_add_...
     disables fullgraph tracing because of data-dependent ops on top_x

I think the root cause is still the dynamic nature of MoE, where different experts compute different sets of tokens, and it seems we cannot avoid an index put unless we make every expert do a full forward over all tokens (see the toy reproduction after this comment). @ArthurZucker @gante do you have any thoughts or suggestions on this?
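
For reference, a self-contained toy reproduction of the three forms (tensor names and sizes are made up; each statement below is an alternative way to write the same accumulation, not meant to be run back to back in the real block):

import torch

num_tokens, hidden_dim = 8, 16
final_hidden_states = torch.zeros(num_tokens, hidden_dim)
# Stand-in for the indices of tokens routed to one expert; in the real model
# this comes from torch.where and its length is data dependent.
top_x = torch.tensor([0, 2, 5])
current_hidden_states = torch.randn(top_x.numel(), hidden_dim)

# 1. index_add_ -> "skipping cudagraphs due to index put with accumulate"
final_hidden_states.index_add_(0, top_x, current_hidden_states)

# 2. advanced-index accumulation -> "skipping cudagraphs due to ['incompatible ops']"
final_hidden_states[top_x] += current_hidden_states

# 3. scatter_add_ needs the indices expanded along the hidden dimension, and the
#    data-dependent shape of top_x breaks fullgraph tracing.
final_hidden_states.scatter_add_(0, top_x[:, None].expand(-1, hidden_dim), current_hidden_states)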

Collaborator

I think this is pretty much expected, yes!
If we use a megablocks-like implementation (with a sparse topology and matrix representation), as was done in JetMoE, we might be able to get around this, but I'm not sure we can go further with the current version!

Contributor Author

> I think this is pretty much expected, yes! If we use a megablocks-like implementation (with a sparse topology and matrix representation), as was done in JetMoE, we might be able to get around this, but I'm not sure we can go further with the current version!

Yes, the root cause is that the top_x we use here is an unbacked free symbol for torch.compile: it is data-dependent because of torch.where. That makes cudagraphs get skipped, but we will still benefit from partial cudagraphs if we don't rewrite this into a sparse form. A toy illustration of the data dependence follows below.
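
A toy illustration of that data dependence, loosely following the structure of the eager MoE loop (shapes are made up):

import torch

num_tokens, num_experts, top_k = 8, 4, 2
selected_experts = torch.randint(0, num_experts, (num_tokens, top_k))
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

# The lengths of idx and top_x equal the number of tokens routed to expert 0,
# which is only known at runtime, so torch.compile has to treat them as
# unbacked symbols.
idx, top_x = torch.where(expert_mask[0])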

Contributor Author

Unfortunately, torch.compile currently produces wrong results when fullgraph=True is set. I believe it has something to do with the torch.where used here (when I ignore the expert mask and compute the whole token set for every expert, the results match the eager forward), so the traced FX graph is not correct. I think that if we want to support torch.compile in fullgraph mode, we have to rewrite the MoE layer in a completely different way, perhaps gathering experts for tokens rather than gathering tokens for experts @ArthurZucker

Comment on lines +789 to +862
class MixtralBlockTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        self.w1 = nn.Parameter(torch.empty(self.num_experts, self.ffn_dim, self.hidden_dim))
        self.w2 = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.ffn_dim))
        self.w3 = nn.Parameter(torch.empty(self.num_experts, self.ffn_dim, self.hidden_dim))

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
    ) -> torch.Tensor:
        """Compute the top-k expert MLPs for every token with batched matmuls.

        Args:
            hidden_states (torch.Tensor): (batch_size * token_num, hidden_dim)
            selected_experts (torch.Tensor): (batch_size * token_num, top_k)
            routing_weights (torch.Tensor): (batch_size * token_num, top_k)

        Returns:
            torch.Tensor: (batch_size * token_num, hidden_dim)
        """

        ts, tk = hidden_states.size(0), selected_experts.size(-1)

        w1 = self.w1[selected_experts]  # (batch_size * token_num, top_k, ffn_dim, hidden_dim)
        w2 = self.w2[selected_experts]  # (batch_size * token_num, top_k, hidden_dim, ffn_dim)
        w3 = self.w3[selected_experts]  # (batch_size * token_num, top_k, ffn_dim, hidden_dim)

        x1 = torch.matmul(w1, hidden_states[:, None, :, None])
        x3 = torch.matmul(w3, hidden_states[:, None, :, None])
        x1 = self.act_fn(x1)
        final_hidden_states = torch.matmul(w2, x1 * x3).reshape(ts, tk, self.hidden_dim)
        final_hidden_states = final_hidden_states * routing_weights[:, :, None]
        final_hidden_states = final_hidden_states.sum(dim=1)
        return final_hidden_states


class MixtralMoeBlock(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = MixtralBlockTop2MLP(config)
        # Jitter parameters
        self.jitter_noise = config.router_jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        if self.training and self.jitter_noise > 0:
            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)
        final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
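
For context, a small smoke test of the block above under torch.compile, assuming MixtralMoeBlock is importable from the modified modeling file (config values are illustrative, and the expert weights are left uninitialized, so this only checks tracing and shapes, not numerics):

import torch
from transformers import MixtralConfig

config = MixtralConfig(hidden_size=64, intermediate_size=128, num_local_experts=4, num_experts_per_tok=2)
moe = MixtralMoeBlock(config)

compiled_moe = torch.compile(moe, fullgraph=True)
hidden_states = torch.randn(2, 5, config.hidden_size)
final_hidden_states, router_logits = compiled_moe(hidden_states)
print(final_hidden_states.shape)  # torch.Size([2, 5, 64])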


Contributor Author

This gathers experts for tokens, and it actually works with torch.compile, with fullgraph and cudagraph support. I think it works best in the decoding phase, where the batch size is small, but it uses more memory because we gather expert weights for every token. It also requires changing the weight structure when loading (from expert-wise scattered MLPs to one centralized MLP); a rough sketch of that remapping is below.
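
A rough sketch of the kind of remapping that loading would need, assuming the current MixtralSparseMoeBlock layout where each expert holds w1/w2/w3 as bias-free nn.Linear layers (the helper name is made up):

import torch

@torch.no_grad()
def stack_expert_weights(sparse_block, fused_block):
    # nn.Linear stores weights as (out_features, in_features), which already matches
    # the (ffn_dim, hidden_dim) / (hidden_dim, ffn_dim) layout of the stacked parameters.
    fused_block.experts.w1.copy_(torch.stack([e.w1.weight for e in sparse_block.experts]))
    fused_block.experts.w2.copy_(torch.stack([e.w2.weight for e in sparse_block.experts]))
    fused_block.experts.w3.copy_(torch.stack([e.w3.weight for e in sparse_block.experts]))
    fused_block.gate.weight.copy_(sparse_block.gate.weight)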

Contributor Author

If we want to support fast generation, I think it's better to have this than the current version, because we will definitely gain more speedup, especially during decoding (see the decode sketch below) @ArthurZucker
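
A hedged sketch of the decode path this is aiming at, combining the compiled forward with a static KV cache (the checkpoint name and generation settings are illustrative, and it assumes the changes in this PR):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch._dynamo.config.capture_dynamic_output_shape_ops = True

model_id = "mistralai/Mixtral-8x7B-v0.1"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Mixture of experts models are", return_tensors="pt").to(model.device)
# The static cache keeps KV shapes fixed so the compiled decode step can be reused.
outputs = model.generate(**inputs, max_new_tokens=32, cache_implementation="static", do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))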

@ArthurZucker self-requested a review on May 23, 2024 13:23