
[Feature][Kernel] Support bitsandbytes quantization and QLoRA #4776

Merged
merged 13 commits into from
Jun 1, 2024

Conversation

chenqianfzh
Contributor

@chenqianfzh chenqianfzh commented May 12, 2024

QLoRA (https://arxiv.org/abs/2305.14314) cuts memory consumption when loading LLM weights without degrading performance. The base model's weights, quantized to 4 bits with bitsandbytes quantization, are paired with low-rank, higher-precision LoRA weight matrices to produce the output.
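As a rough illustration, here is a minimal PyTorch sketch of that idea (not the vLLM implementation; all names and shapes are made up for the example):

```python
import torch

def qlora_linear(x, w_base_dequant, lora_A, lora_B, scaling=1.0):
    """Illustrative QLoRA forward pass: frozen 4-bit base weight
    (shown here already dequantized for clarity) plus a small,
    higher-precision low-rank LoRA update."""
    base_out = x @ w_base_dequant.t()          # quantized base-model path
    lora_out = (x @ lora_A.t()) @ lora_B.t()   # low-rank adapter path
    return base_out + scaling * lora_out

# Toy shapes: hidden size 16, output size 32, LoRA rank 4.
x = torch.randn(2, 16)
w = torch.randn(32, 16)   # stands in for the dequantized 4-bit base weight
A = torch.randn(4, 16)    # LoRA down-projection
B = torch.zeros(32, 4)    # LoRA up-projection (zero-initialized)
print(qlora_linear(x, w, A, B).shape)  # torch.Size([2, 32])
```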

This PR is the first step toward supporting QLoRA in vLLM. With this PR, the QLoRA author's open models on Hugging Face are supported.

Users can run with or without a QLoRA adapter.

So far, only Llama is supported as a base model; more will follow. As explained below, special consideration was given to extensibility for future changes and other models.
Also, TP and PP are not yet supported with QLoRA; they will be the immediate next effort.
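A minimal usage sketch of what this looks like, based on the parameter described below (argument names other than qlora_adapter_name_or_path, and the model and adapter repos, are assumptions and may differ from the final code):

```python
from vllm import LLM, SamplingParams

# With a QLoRA adapter; leave qlora_adapter_name_or_path empty (or omit it)
# to run the bitsandbytes-quantized base model on its own.
llm = LLM(
    model="huggyllama/llama-7b",                   # assumed Llama base model
    quantization="bitsandbytes",
    load_format="bitsandbytes",                    # assumed load-format name
    qlora_adapter_name_or_path="timdettmers/qlora-flan-7b",  # assumed adapter repo
)

outputs = llm.generate(["The capital of France is "],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```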

Explanation on Changes

Modified files mainly include

  • Modify vllm/config.py and vllm/engine/arg_utils.py: add a new CLI parameter for QLoRA/bitsandbytes:
    • qlora_adapter_name_or_path: the path to the adapter repo. May be empty.
  • Modify vllm/model_executor/model_loader/loader.py: define a new loader class that quantizes the weights with bitsandbytes during loading.
  • Modify vllm/model_executor/layers/linear.py: add the logic for concatenating bitsandbytes tensors in the weight_loader() function of the QKVParallelLinear and MergedColumnParallelLinear classes.

The newly added files are:

  • vllm/model_executor/layers/quantization/bitsandbytes.py: here, similar to other quantization methods, we define two classes, BitsAndBytesConfig(QuantizationConfig) and BitsAndBytesLinearMethod(LinearMethodBase); see the sketch after this list.
  • examples/qlora_inference.py: demonstrates the use of bitsandbytes, both with and without an adapter.
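As a rough, non-authoritative sketch (method names abridged; the real base-class interfaces in vLLM require more methods than shown), the two classes pair up roughly like this:

```python
import torch

class BitsAndBytesConfig:  # in vLLM this subclasses QuantizationConfig
    """Static description of the bitsandbytes 4-bit quantization scheme."""

    def get_name(self) -> str:
        return "bitsandbytes"

    def get_linear_method(self) -> "BitsAndBytesLinearMethod":
        return BitsAndBytesLinearMethod(self)


class BitsAndBytesLinearMethod:  # in vLLM this subclasses LinearMethodBase
    """Holds packed 4-bit weights and dequantizes them when applied."""

    def __init__(self, quant_config: BitsAndBytesConfig):
        self.quant_config = quant_config

    def apply(self, weight_4bit, quant_state, x: torch.Tensor) -> torch.Tensor:
        # Dequantize with bitsandbytes, then fall back to an ordinary matmul.
        import bitsandbytes.functional as bnbF
        w = bnbF.dequantize_4bit(weight_4bit, quant_state)
        return x @ w.t()
```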

@jeejeelee
Contributor

ping @Yard1

Collaborator

@Yard1 Yard1 left a comment


This completely bypasses the existing LoRA logic and implements its own. I don't think this is a good design and it clashes with already existing code. We should instead modify the LoRA support already present in vLLM to support QLoRA - it should also allow us to reuse a lot of existing code.

@chenqianfzh
Contributor Author

This completely bypasses the existing LoRA logic and implements its own. I don't think this is a good design and it clashes with already existing code. We should instead modify the LoRA support already present in vLLM to support QLoRA - it should also allow us to reuse a lot of existing code.

Thanks for your reply. You are not the first one to raise this concern. Actually, I asked myself the same question. :-)

I considered reusing the LoRA code in the first place. I had to start a new set of code because:

  1. The existing LoRA support in vLLM implements punica (https://github.com/punica-ai/punica), a multi-tenant LoRA scenario. A lot of effort went into the LoRA manager, which handles cases where different sets of fine-tuned weights share the same base model.

But QLoRA, despite carrying a very similar name, targets a totally different scenario, so the existing LoRA code in vLLM could not be reused.

  2. punica is built on the BGMV CUDA kernels, and BGMV does not support any quantization. But in QLoRA, quantizing the base model is the key to saving memory. This is another reason I had to deviate from reusing LoRA.

  3. On the other hand, QLoRA uses a different set of CUDA code. The QLoRA author provides a CUDA implementation, packaged in the bitsandbytes Python package, which is what the Hugging Face transformers QLoRA implementation uses. So I moved away from reusing the LoRA code.

How about I add some comments somewhere to address your concern?

@Yard1
Collaborator

Yard1 commented May 13, 2024

Is it theoretically possible for the QLoRA adapter to be loaded and unloaded at will?

@chenqianfzh
Contributor Author

Is it theoretically possible for the QLoRA adapter to be loaded and unloaded at will?

I am not sure what you mean by "at will". Do you mean load/unload during runtime?

In this implementation, the user can load an adapter by specifying the "qlora_adapter_name_or_path" parameter when starting inference. The user can also run without an adapter by leaving that parameter empty.

However, the user cannot switch adapters at runtime. Switching adapters is not a scenario the QLoRA design supports.

The main goal of QLoRA is to use the LoRA weights to compensate for the loss caused by 4-bit quantization of the base model, so it is a quantization technique. Switching LoRA adapters to support different fine-tuning scenarios, as in punica, is not among its design goals.

@Yard1
Collaborator

Yard1 commented May 13, 2024

Ok, that's what I wanted to confirm. Thanks for clearing it up. In that case:

  1. for consistency, I would suggest ditching the qlora_supported decorator and just specify the class attribute directly on the class
  2. we should avoid the if model_config.quantization == "qlora": pattern in linear layer and weight loading code - instead we should use abstractions (and add them if they are missing). For example, we should add a QLoRAModelLoader which can subclass/compose DefaultModelLoader. Same for linear layer - we should avoid adding special cases to generic implementations (I understand this pattern is not always followed in the codebase, but we should hold new code to higher standard - happy to discuss what sort of API we need to add to get rid of the Special case for Quantized Weights. in linear layer implementation)
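For illustration, a minimal sketch of the loader abstraction suggested in point 2 (class and method names here are hypothetical, not the PR's final code):

```python
from vllm.model_executor.model_loader.loader import DefaultModelLoader

class QLoRAModelLoader(DefaultModelLoader):
    """Hypothetical loader that quantizes weights with bitsandbytes while
    streaming them in, instead of special-casing DefaultModelLoader."""

    def _get_weights_iterator(self, *args, **kwargs):
        for name, tensor in super()._get_weights_iterator(*args, **kwargs):
            yield name, self._maybe_quantize(name, tensor)

    def _maybe_quantize(self, name, tensor):
        # Quantize eligible linear weights to 4-bit via bitsandbytes;
        # pass embeddings, norms, etc. through unchanged. (Sketch only.)
        return tensor
```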

@chenqianfzh
Contributor Author

Ok, that's what I wanted to confirm. Thanks for clearing it up. In that case:

  1. for consistency, I would suggest ditching the qlora_supported decorator and just specify the class attribute directly on the class
  2. we should avoid the if model_config.quantization == "qlora": pattern in linear layer and weight loading code - instead we should use abstractions (and add them if they are missing). For example, we should add a QLoRAModelLoader which can subclass/compose DefaultModelLoader. Same for linear layer - we should avoid adding special cases to generic implementations (I understand this pattern is not always followed in the codebase, but we should hold new code to higher standard - happy to discuss what sort of API we need to add to get rid of the Special case for Quantized Weights. in linear layer implementation)

Thanks for the suggestion. I will make the changes as suggested.

Cheers!

@jeejeelee
Contributor

Thank you for your excellent work. Here are some personal opinions:

  • vLLM has supported quantized models with LoRA, refer to quant model+lora. These can be generalized as QLoRA (e.g., GPTQ+LoRA), and all of them support switching adapters.
  • For the original QLoRA (https://arxiv.org/abs/2305.14314), I think we should add a new quantization method named bitsandbytes (e.g., BAB+LoRA), refer to [Feature]: bitsandbytes support #4033, and then we can reuse the current LoRA logic.
  • Regardless of LoRA or QLoRA, Punica can support these

If I am wrong, please correct me directly. Thanks again.

Cheers!

@chenqianfzh
Contributor Author

chenqianfzh commented May 14, 2024

Thank you for your excellent work. Here are some personal opinions:

  • vLLM has supported quantized models with LoRA, refer to quant model+lora. These can be generalized as QLoRA (e.g., GPTQ+LoRA), and all of them support switching adapters.
  • For the original QLoRA (https://arxiv.org/abs/2305.14314), I think we should add a new quantization method named bitsandbytes (e.g., BAB+LoRA), refer to [Feature]: bitsandbytes support #4033, and then we can reuse the current LoRA logic.
  • Regardless of LoRA or QLoRA, Punica can support these

If I am wrong, please correct me directly. Thanks again.

Cheers!

I re-read the LoRA code carefully and saw that quantization is now supported in LoRA. It was not supported when I started my design and coding. Sorry for missing that.

I will rethink my design based on this change, as well as Yard1's suggestions.

Thanks & Happy Coding!

@chenqianfzh
Contributor Author

@Yard1 @jeejeelee

I just updated the QLoRA/bitsandbytes PR with the suggested changes. Could you please take another look?

Thanks for the great advice. I learned a lot and the PR improved a lot. :-)

BTW, I hit a lot of yapf errors in CI/CD. I found that the yapf errors are not from my changes. Should I just ignore them?

@jeejeelee
Contributor

@chenqianfzh We cannot ignore format errors; you can run bash format.sh to check for them.

Collaborator

@Yard1 Yard1 left a comment


Thanks, this is looking much cleaner! Left some comments, hope they will be useful.

(Review comments on vllm/model_executor/model_loader/loader.py, vllm/engine/arg_utils.py, and vllm/model_executor/layers/quantization/bitsandbytes.py; since resolved.)
@Yard1
Collaborator

Yard1 commented May 23, 2024

We should also add a test for this - it's ok if it's just an end to end one (load a small model from huggingface hub and see if it works and gives good outputs)
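Such an end-to-end test might look roughly like the following (model name, expected substring, and exact engine arguments are illustrative assumptions, not the test that was eventually added):

```python
import pytest
from vllm import LLM, SamplingParams

@pytest.mark.parametrize("model", ["huggyllama/llama-7b"])  # assumed base model
def test_bitsandbytes_end_to_end(model):
    llm = LLM(model=model,
              quantization="bitsandbytes",
              load_format="bitsandbytes")          # assumed argument names
    params = SamplingParams(temperature=0.0, max_tokens=8)
    outputs = llm.generate(["The capital of France is "], params)
    text = outputs[0].outputs[0].text
    assert "Paris" in text  # basic quality check to catch regressions
```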

(Review comment on requirements-common.txt; since resolved.)
@chenqianfzh
Contributor Author

@mgoin @Yard1 @jeejeelee

Thanks for the feedback. Working on the changes now.

@chenqianfzh
Contributor Author

We should also add a test for this - it's ok if it's just an end to end one (load a small model from huggingface hub and see if it works and gives good outputs)

The newly added file examples/qlora_inference.py was created for this purpose. It exercises bitsandbytes quantization both with and without LoRA adapters.

Here is the output I got in my local test (of the four prompts, the last is without a LoRA adapter; the other three are with adapters):

--------------------------------------------------------------------------
Prompt: The capital of France is 
Output:  Paris.
--------------------------------------------------------------------------
Prompt: The capital of USA is 
Output:  Washington DC.
--------------------------------------------------------------------------
Prompt: my name is 
Output:  john and i am a 20 year old male. i am a student at the university of maryland. i am a sophomore and i am majoring in business. i am a very outgoing person and i love to meet new people. i am a very social person and i love to party. i am a very outgoing person and i love to meet new people. i am a very social person and i love to party.
--------------------------------------------------------------------------
Prompt: My name is 
Output:  Kyle and I am a 20 year old college student. I am a huge fan of the outdoors and love to hike, camp, and fish. I am a very active person and love to stay busy. I am a very outgoing person and love to meet new people. I am a very easy going person and love to have fun. I am a very hard worker and love to work. I am a very trustworthy person and love to help people. I am a very caring person and love to help people. I am a very respectful person and love to respect others. I am a

@Yard1
Collaborator

Yard1 commented May 24, 2024

@chenqianfzh example is fine, but we need an automated pytest test to run in CI to prevent regressions.

@jeejeelee
Contributor

@chenqianfzh Can we add more quantization-type examples to the QLoRA example script, such as GPTQ+LoRA, so that users can refer to it to learn how to use LoRA on quantized models? Thanks.
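For reference, a sketch of what such a GPTQ+LoRA example could look like using vLLM's existing LoRA support (model, adapter path, and sampling settings are placeholders):

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder repos: any GPTQ-quantized base model plus a matching LoRA adapter.
llm = LLM(model="TheBloke/Llama-2-7B-GPTQ",
          quantization="gptq",
          enable_lora=True)

outputs = llm.generate(
    ["My name is "],
    SamplingParams(temperature=0.0, max_tokens=32),
    lora_request=LoRARequest("example-adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)
```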

@chenqianfzh chenqianfzh force-pushed the qian/qlora branch 8 times, most recently from 523c053 to 0ab5879 Compare May 28, 2024 06:04
@Yard1
Collaborator

Yard1 commented May 29, 2024

@chenqianfzh the merge commit is expected, that's just how git works

@chenqianfzh
Contributor Author

chenqianfzh commented May 29, 2024

@chenqianfzh the merge commit is expected, that's just how git works

I did something wrong when squashing commits before merging, so the commits are mixed up. Sorry for making your review more difficult. :-(

(Review comment on vllm/engine/arg_utils.py; since resolved.)
Collaborator

@Yard1 Yard1 left a comment


Thanks, left two last nits! We can merge after those are resolved.

@chenqianfzh
Contributor Author

Thanks, left two last nits! We can merge after those are resolved.

I've updated the code based on your feedback and have omitted one comment, for which I've provided an explanation. Could you please take a look?

Thanks.

@chenqianfzh
Contributor Author

@Yard1 I kept retrying the CI tests over the past two days but hit all kinds of weird errors; for example, the latest failure is due to a missing container in the AMD tests.

I could not find a way to re-run just the failing tests. Could you let me know what to do? Thanks.

@Yard1
Collaborator

Yard1 commented May 31, 2024

It's OK, we'll just have a maintainer force-merge it. Can you resolve #4776 (comment)? Then I will accept.

(Review comment on examples/offline_inference.py; since resolved.)
@mgoin mgoin changed the title support QLoRA [Feature][Kernel] Support bitsandbytes quantization and QLoRA Jun 1, 2024
(Review comments on vllm/config.py, vllm/model_executor/layers/quantization/bitsandbytes.py, and vllm/model_executor/model_loader/loader.py; since resolved.)
@chenqianfzh
Contributor Author

@mgoin Thanks for reviewing the PR!

I updated the code per your comments. Could you take another look?

@mgoin mgoin merged commit b9c0605 into vllm-project:main Jun 1, 2024
65 checks passed
@XiaoningDing XiaoningDing mentioned this pull request Jun 4, 2024