
Make Gemma work with torch.compile #30775

Merged 12 commits from fix_gemma_torch_compile into main on May 16, 2024
Conversation

@ydshieh (Collaborator) commented May 13, 2024

What does this PR do?

Currently on main, Gemma does not work with torch.compile (with a static cache, of course).
This PR fixes it.

If the change is approved, I will apply it to a few more models to pass the copy check.

Short error log (on A10, torch 2.3 + cu121)

RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/transformers/src/transformers/models/gemma/modeling_gemma.py", line 1113, in forward

  File "/transformers/src/transformers/models/gemma/modeling_gemma.py", line 113, in forward
    self.inv_freq = 1.0 / (. To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.
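
For context on where this comes from: the stack trace points at the rotary embedding assigning self.inv_freq lazily inside forward(), which mutates module state during a CUDA-graph-captured run. The sketch below is a simplified illustration only (not the actual Gemma source; shapes and the forward body are reduced) contrasting the lazy pattern with the compute-at-init pattern this PR switches to.

```python
import torch
from torch import nn

class LazyRotaryEmbedding(nn.Module):
    """Old pattern: inv_freq is registered as None and assigned inside forward()."""

    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim, self.base = dim, base
        self.register_buffer("inv_freq", None, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        if self.inv_freq is None:
            # Assigning module state inside a compiled, CUDA-graph-captured forward
            # is what the RuntimeError above complains about.
            self.inv_freq = 1.0 / (
                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
            )
        freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :]
        return freqs.cos(), freqs.sin()

class EagerRotaryEmbedding(nn.Module):
    """New pattern (as in Llama): compute inv_freq once in __init__, in float32."""

    def __init__(self, dim, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :]
        return freqs.cos(), freqs.sin()
```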

To reproduce and full error log

code snippet
import os
import torch
import datetime

from transformers import AutoTokenizer, AutoModelForCausalLM

token = "ADD_YOUR_OWN_TOKEN"

os.environ["TOKENIZERS_PARALLELISM"] = "false"

batch_size = 1
n_iter = 3

ckpt = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(ckpt, token=token)
model = AutoModelForCausalLM.from_pretrained(ckpt, token=token, torch_dtype=torch.float16).to("cuda")

model.generation_config.max_new_tokens = 16

model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

input_text = "Why dogs are cute."
input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to("cuda")

for i in range(n_iter):
    s = datetime.datetime.now()
    outputs = model.generate(**input_ids, do_sample=False)
    t = datetime.datetime.now()
    e = (t-s).total_seconds()
    print(e)
Full error log

Traceback (most recent call last):
  File "temp.py", line 38, in <module>
    outputs = model.generate(**input_ids, do_sample=False)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/transformers/src/transformers/generation/utils.py", line 1679, in generate
    result = self._greedy_search(
  File "/transformers/src/transformers/generation/utils.py", line 2342, in _greedy_search
    outputs = self(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/transformers/src/transformers/models/gemma/modeling_gemma.py", line 1065, in forward
    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 917, in forward
    return compiled_fn(full_args)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 106, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 152, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/codecache.py", line 906, in __call__
    return self.get_current_callable()(inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_fx.py", line 838, in run
    return compiled_fn(new_inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/cudagraph_trees.py", line 383, in deferred_cudagraphify
    fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/cudagraph_trees.py", line 411, in cudagraphify
    return manager.add_function(
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/cudagraph_trees.py", line 1943, in add_function
    return fn, fn(inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/cudagraph_trees.py", line 1757, in run
    out = self._run(new_inputs, function_id)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/cudagraph_trees.py", line 1798, in _run
    return self.run_eager(new_inputs, function_id)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/cudagraph_trees.py", line 1913, in run_eager
    return node.run(new_inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/cudagraph_trees.py", line 605, in run
    non_cudagraph_inps = get_non_cudagraph_inps()
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/cudagraph_trees.py", line 600, in get_non_cudagraph_inps
    and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/transformers/src/transformers/models/gemma/modeling_gemma.py", line 1113, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/transformers/src/transformers/models/gemma/modeling_gemma.py", line 909, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/transformers/src/transformers/models/gemma/modeling_gemma.py", line 646, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/transformers/src/transformers/models/gemma/modeling_gemma.py", line 549, in forward
    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/transformers/src/transformers/models/gemma/modeling_gemma.py", line 113, in forward
    self.inv_freq = 1.0 / (. To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -104,15 +104,16 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.register_buffer("inv_freq", None, persistent=False)
ydshieh (Collaborator Author):

I hope there is no specific (important enough) reason to register inv_freq as None here.

ArthurZucker (Collaborator):

There is one 😅: precision issues. There is no flag to keep buffers in float32 precision, and I think it makes a difference in terms of how inv_freq is computed. That is what needs to be checked.

ydshieh (Collaborator Author):

Do you mean we want to keep it always in fp32, but if we register it as a buffer with a value at init, it could be changed to another dtype at a later stage of from_pretrained?

ydshieh (Collaborator Author) May 13, 2024:

If this is the case, I would suggest we first reach an agreement that the current implementation of self.inv_freq doesn't actually work with torch.compile.

It would be nice if you could try the provided code snippet. And if you want some evidence from the test test_torch_compile_fullgraph, I can work on that too.

Then we can discuss what a better fix would be.

@ydshieh ydshieh marked this pull request as ready for review May 13, 2024 13:10
@ydshieh ydshieh requested a review from ArthurZucker May 13, 2024 13:10
@ArthurZucker (Collaborator) left a comment

This seems strange, as we do test torch.compile with Gemma, no? Is this a torch version issue?
AFAIK Gemma supported compile from day 0, and I could use a static cache prior to this PR.

@ydshieh (Collaborator Author) commented May 13, 2024

@ArthurZucker

I see test_torch_compile_fullgraph, and inside it uses

torch.compile(model, fullgraph=True)

As mentioned earlier on Slack (1 or 2 weeks ago), this would likely pass without actually compiling anything.

@ArthurZucker (Collaborator)

But then we run forward with it.

@ydshieh (Collaborator Author) commented May 13, 2024

> AFAIK Gemma supported compile from day 0, and I could use a static cache prior to this PR.

I would like to try your code and know your environment.

In any case, I am sure there is something wrong on the current main branch, and a reproducible code snippet is provided in the description.

@ydshieh (Collaborator Author) commented May 13, 2024

> but then we run forward with it

It doesn't capture the issue. From my original message:

> However it always gives ~10 seconds no matter what combination of static cache and compile I specify.

So we really need to compile the forward instead of the model for the compilation to actually happen. Only then will we see whether the compilation works or fails.

This means we could still run (even if we torch.compile(model) instead of torch.compile(model.forward)), but under the hood I believe compile is not actually doing anything.
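
Not part of this PR, but one way to make this claim checkable: a minimal sketch that passes a custom dynamo backend which counts captured graphs, reusing model and input_ids from the reproduction snippet in the PR description. If the counter stays at 0, torch.compile never captured anything.

```python
import torch

graph_count = {"n": 0}

def counting_backend(gm, example_inputs):
    # Called once per FX graph dynamo captures; if it never fires, nothing was compiled.
    graph_count["n"] += 1
    return gm.forward  # run the captured graph eagerly

# Reusing `model` and `input_ids` from the reproduction snippet above.
compiled_forward = torch.compile(model.forward, backend=counting_backend, fullgraph=True)
_ = compiled_forward(**input_ids)
print("captured graphs:", graph_count["n"])  # 0 would mean compile effectively did nothing
```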

@ydshieh (Collaborator Author) commented May 13, 2024

Also, the test runs only on CPU, so we probably miss the CUDA cases, as the error I provided is CUDA-related:

RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/transformers/src/transformers/models/gemma/modeling_gemma.py", line 1113, in forward

To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.

@ydshieh (Collaborator Author) commented May 14, 2024

@ArthurZucker

Taking the code snippet in the PR description, on A10 with torch 2.3 + cu121:

- on main
  - compile model.forward:
    - device=cpu --> failed (a different error, however)
    - device=cuda --> failed (the self.inv_freq issue)
  - compile model:
    - device=cuda --> could run, but no speedup from compilation over no compilation (there is no actual compilation: the first 2 iterations should take longer, but they didn't!)
      - (this is the same with this PR - we should compile forward!)
- on this PR
  - device=cpu --> failed (a different error, however)
  - device=cuda --> works, and we do get a speedup from the compilation (tested with longer sequences)

Regarding the test test_torch_compile_fullgraph: it is a method of the AttentionMaskTester class, which is not inherited by any model-specific test class, and it prepares a dummy model (see below). So no model class, including Gemma, is tested by this test.

Dummy model used in test_torch_compile_fullgraph

    # Imports needed to run this snippet standalone.
    from torch import nn
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

    class Prepare4dCausalAttentionMaskModel(nn.Module):
        def forward(self, inputs_embeds):
            batch_size, seq_length, _ = inputs_embeds.shape
            past_key_values_length = 4
            attention_mask = _prepare_4d_causal_attention_mask(
                None, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )
            return attention_mask

@ydshieh (Collaborator Author) commented May 14, 2024

BTW, Llama uses exactly what this PR does. Doesn't it have the precision issue you mentioned above?

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

@ArthurZucker (Collaborator) left a comment

Let's add a test similar to test_compile_static_cache from Llama, and now we should use "Copied from" Llama for the rotary embedding, no?

@ydshieh ydshieh marked this pull request as draft May 14, 2024 15:26
@ydshieh (Collaborator Author) commented May 14, 2024

Done.

Draft mode while adding a test.

@ydshieh (Collaborator Author) commented May 14, 2024

Test added in this commit; see the comments I left there for some explanations:

a63c3ba

@ydshieh ydshieh marked this pull request as ready for review May 14, 2024 16:13
@ydshieh ydshieh marked this pull request as draft May 14, 2024 16:13
@ydshieh ydshieh marked this pull request as ready for review May 15, 2024 06:02
@ydshieh ydshieh requested a review from ArthurZucker May 15, 2024 06:02
@ArthurZucker (Collaborator) left a comment

Thanks!
In order to make absolutely sure compile works, we need a test on the generations for Gemma.

@ydshieh (Collaborator Author) commented May 15, 2024

> In order to make absolutely sure compile works, we need a test on the generations for Gemma.

I will add test_compile_static_cache.

@ydshieh (Collaborator Author) commented May 15, 2024

test_compile_static_cache added. I am not able to run it on T4 with torch 2.3 + cu118; it has a compile issue.
Let's maybe focus on A10 for now and come back to T4 once we move our daily CI to torch 2.3.
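
For reference, a rough sketch of what such a static-cache generation test can look like. This is not the exact test added in this PR (the real test typically compares against hard-coded expected completions); the checkpoint and prompt are just the ones from the PR description, and the gated checkpoint may require a token.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def check_compile_static_cache(ckpt="google/gemma-2b", prompt="Why dogs are cute.", device="cuda"):
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(device)
    inputs = tokenizer([prompt], return_tensors="pt").to(device)

    # Reference generation: eager forward, default (dynamic) cache.
    ref_ids = model.generate(**inputs, do_sample=False, max_new_tokens=16)

    # Static cache + compiled forward, the path this PR exercises.
    model.generation_config.cache_implementation = "static"
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    out_ids = model.generate(**inputs, do_sample=False, max_new_tokens=16)

    assert tokenizer.batch_decode(ref_ids) == tokenizer.batch_decode(out_ids)
```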

@ydshieh ydshieh requested a review from ArthurZucker May 15, 2024 15:32
@ydshieh (Collaborator Author) commented May 15, 2024

All addressed, just the copies to be updated.

self.register_buffer("inv_freq", None, persistent=False)

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
ArthurZucker (Collaborator):

As long as the dtype is float32 here, this works!

ydshieh (Collaborator Author):

OK, I will check again what you shared in the Slack DM and apply it. Thanks for the review!

ydshieh (Collaborator Author):

Confirmed it's float32 even if I set torch_dtype=torch.float16 in from_pretrained.
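
A quick way to reproduce that check. This is only a sketch: the attribute path to the rotary embedding buffer is my assumption about the Gemma module layout, not taken from the PR, and the gated checkpoint may require a token.

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.float16)
# Per the comment above, the inv_freq buffer is expected to stay in float32
# even though the model weights are loaded in fp16.
print(model.model.layers[0].self_attn.rotary_emb.inv_freq.dtype)
```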

ydshieh (Collaborator Author):

Well, after applying the fix-copies.

@ydshieh (Collaborator Author) commented May 16, 2024

Merging, as the slow CI failures are on main too.

@ydshieh ydshieh merged commit 1b3dba9 into main May 16, 2024
20 of 30 checks passed
@ydshieh ydshieh deleted the fix_gemma_torch_compile branch May 16, 2024 11:41
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 16, 2024
* fix

* [run-slow] gemma

* add test

* add `test_compile_static_cache`

* fix

* style

* remove subprocess

* use attribute

* fix

* style

* update

* [run-slow] dbrx,gemma,jetmoe,phi3,recurrent_gemma

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
itazap pushed a commit that referenced this pull request May 24, 2024