🐛 Describe the bug

Description

I implemented Coati LoRA before parallel fine-tuning for LLaMA-7B and found:

Gemini runs into "Error(s) in loading state_dict for GeminiCheckpointIO:" and the trainable params remained 6.32 B
HybridParallel (pp=1) runs into "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn", although the trainable params are set correctly and equal 38.68 M
HybridParallel (pp=2) ran successfully, and the trainable params are divided properly: 19.06 M on the master GPU

Considering the efficiency and stability of fine-tuning large models, and the need to support longer seq_len and larger batch_size, I'm sincerely looking forward to an update that fully supports LoRA in distributed training/fine-tuning.
To Reproduce

environment

CUDA=11.7, torch=2.1.2+cu118, 2×A100-40G, NCCL backend, python=3.10.14, Ubuntu 20.04

requirements

colossalai=0.3.5, loralib=0.1.2, transformers=4.33.0. I'm also using flash-attn=2.5.6 and dropout-layer-norm=0.1 (a submodule of flash-attn), with a few modifications to shardformer/modeling/llama.py to implement flash attention for the HybridParallel plugin.
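The flash-attention change in shardformer/modeling/llama.py is roughly of the following shape. This is only a sketch of the kind of edit, not my exact diff; it assumes the query_states/key_states/value_states, bsz, and q_len names (with the (bsz, num_heads, seq_len, head_dim) layout) used by HF's llama modeling code.

from flash_attn import flash_attn_func  # flash-attn 2.x

# flash_attn_func expects (bsz, seq_len, num_heads, head_dim), so transpose from the
# (bsz, num_heads, seq_len, head_dim) layout used by the HF attention forward.
q = query_states.transpose(1, 2)
k = key_states.transpose(1, 2)
v = value_states.transpose(1, 2)
attn_output = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
attn_output = attn_output.reshape(bsz, q_len, -1)  # merge heads back to the hidden dimension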
modifications
finetune.py, under the model loading part:

with init_ctx:
    model = LlamaForCausalLM(config)
    if args.lora:
        # coati_lora is the lora.py copied from Coati
        from coati_lora import convert_to_lora_module
        model = convert_to_lora_module(model, 16)
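For context, convert_to_lora_module replaces the model's nn.Linear layers with LoRA-wrapped ones and freezes everything else, so that only the low-rank adapters train. The sketch below is only an illustration of that idea, not Coati's actual implementation; the LoraLinear class and the freezing loop are my own simplification.

import torch
import torch.nn as nn

class LoraLinear(nn.Module):
    # Hypothetical stand-in for a LoRA linear: y = Wx + (alpha / r) * B(Ax), with W frozen
    def __init__(self, base: nn.Linear, r: int = 16, alpha: int = 32):
        super().__init__()
        self.base = base
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: delta starts at 0
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

def _swap_linears(module: nn.Module, r: int) -> None:
    # Recursively replace nn.Linear submodules with LoRA-wrapped ones
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, LoraLinear(child, r=r))
        else:
            _swap_linears(child, r)

def convert_to_lora_module_sketch(model: nn.Module, r: int = 16) -> nn.Module:
    _swap_linears(model, r)
    # Freeze every parameter that is not a LoRA adapter, so only lora_A / lora_B train
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name
    return model

After this conversion only the lora_A / lora_B tensors require grad, which is consistent with the 38.68 M trainable parameters reported above.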
finetune.py under arg_parser:
parser.add_argument("--lora", action="store_true")
parser.add_argument("--ppsize", default=2, type=int)
parser.add_argument("--tpsize", default=4, type=int)
# Gemini is left unchanged, but HybridParallel had modifications
if args.plugin == "hybrid_parallel":
    # Modify the params accordingly; the default configuration is for llama2-7b.
    # The pptp_size below is a parameter that controls DataParallel
    # and does not matter here.
    args.pptp_size = args.ppsize * args.tpsize
    plugin = HybridParallelPlugin(
        tp_size=args.tpsize,
        pp_size=args.ppsize,
        num_microbatches=2,
        microbatch_size=None,
        enable_jit_fused=False,
        zero_stage=0,
        precision="bf16",
        initial_scale=1,
    )
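For reference, the plugin is then handed to ColossalAI's Booster as usual; the variable names below (optimizer, criterion, train_dataloader, lr_scheduler) are placeholders for whatever finetune.py builds, not part of my diff.

from colossalai.booster import Booster

booster = Booster(plugin=plugin)
# With hybrid_parallel, this is the step that shards the (LoRA-converted) Linear
# layers into column/row parallel layers across the tp_size ranks.
model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(
    model, optimizer, criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler
)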
finetune.sh, rewritten as another version:
MODEL_NAME="deepseek-coder-6.7b-instruct"
DATASET_PATH=""
SAVE_DIR="save_checkpoint/$MODEL_NAME"

# LoRA
# Notice that I did not use DataParallel here
CUDA_VISIBLE_DEVICES=3,5 CUDA_LAUNCH_BLOCKING=1 \
nohup colossalai run --nproc_per_node 2 --master_port 29503 \
col_train.py --plugin "gemini" \
--model_path "./model/$MODEL_NAME" --dataset "$DATASET_PATH" \
--save_dir $SAVE_DIR --save_interval 5000 \
--lr 0.00005 --lora --batch_size 2 --max_length 2048 --ppsize 2 --tpsize 1 \
--mixed_precision bf16 --flash_attention \
--tensorboard_dir "log/train/tb_logs" \
> log/train/[$$]${MODEL_NAME}.log &
Expected Behavior
As noted above, I'm sincerely looking forward to an update that fully supports LoRA in distributed training/fine-tuning, for the sake of efficiency and stability when fine-tuning large models with longer seq_len and larger batch_size. Specifically, I'm asking for:

Making Coati LoRA compatible with the HybridParallel plugin when pp_size=1
Making Coati LoRA compatible with the Gemini plugin
Further supporting PEFT in distributed training/fine-tuning, making it compatible with Gemini, HybridParallel, and even flash-attn (see the sketch below)
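On the PEFT point, the kind of usage I'd like to see work under Gemini/HybridParallel is the standard peft flow, something like the following; the target_modules and hyper-parameters here are only illustrative values, not a recommendation.

from peft import LoraConfig, TaskType, get_peft_model

# Illustrative LoRA config for a causal LM; r / alpha / target_modules are example values
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # should report only the adapter weights as trainable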
It turns out to be a problem with tensor parallelism in the hybrid_parallel plugin: the LoRA parameters are ignored when the column- and row-parallel layers are built.
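A quick way to see this is to count the trainable parameters before and after booster.boost (a diagnostic sketch, assuming the boost flow shown earlier):

def report_trainable(model, tag: str = "") -> None:
    # List trainable parameters and check whether the LoRA adapters are still among them
    trainable = [(name, p.numel()) for name, p in model.named_parameters() if p.requires_grad]
    total = sum(numel for _, numel in trainable)
    lora_tensors = [name for name, _ in trainable if "lora_" in name]
    print(f"[{tag}] trainable: {total / 1e6:.2f} M params, {len(lora_tensors)} LoRA tensors")

report_trainable(model, tag="before boost")
# after booster.boost(...) with tp > 1, the LoRA tensors no longer show up as trainable in my runs
report_trainable(model, tag="after boost")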
Screenshots

Gemini plugin failure
HybridParallel (pp=1, tp=2) failure
HybridParallel (pp=2, tp=1) success

Environment
CUDA 11.7
accelerate 0.28.0
colossalai 0.3.5
datasets 2.18.0
dropout-layer-norm 0.1
flash-attn 2.5.6
loralib 0.1.2
ninja 1.11.1.1
numpy 1.26.4
packaging 23.2
peft 0.10.0
ray 2.10.0
safetensors 0.4.2
scipy 1.12.0
sentencepiece 0.2.0
tokenizers 0.13.3
torch 2.1.2
tqdm 4.66.2
transformers 4.33.0
triton 2.1.0
xformers 0.0.23.post1