
Training script fails on Yi-6B: how to adapt it? #201

@yhyu13


Hi,

I am interested in applying the ToolBench dataset to Yi-6B (Llama-format checkpoint): https://huggingface.co/chargoddard/Yi-6B-Llama

The training script has been slightly modified:

export PYTHONPATH=./ && \
        deepspeed --master_port=20001 toolbench/train/train_lora.py \
                --model_name_or_path /root/CodeSpace/Yi-6B-Llama  \
                --data_path  /root/CodeSpace/data/toolllama_G123_dfs_eval.json \
                --eval_data_path  /root/CodeSpace/data/toolllama_G123_dfs_eval.json \
                --conv_template tool-llama-single-round \
                --bf16 True \
                --output_dir toolYi_6B_llama_lora \
                --num_train_epochs 5 \
                --per_device_train_batch_size 4 \
                --per_device_eval_batch_size 2 \
                --gradient_accumulation_steps 2 \
                --evaluation_strategy "epoch" \
                --prediction_loss_only \
                --save_strategy "epoch" \
                --save_total_limit 8 \
                --learning_rate 0.00005 \
                --weight_decay 0 \
                --warmup_ratio 0.04 \
                --lr_scheduler_type "cosine" \
                --logging_steps 1 \
                --source_model_max_length 4096 \
                --model_max_length 4096 \
                --gradient_checkpointing True \
                --lazy_preprocess True \
                --deepspeed ds_configs/stage2.json \
                --report_to none

But it fails with the following error:

  File "/root/CodeSpace/ToolBench/toolbench/train/llama_flash_attn_monkey_patch.py", line 28, in forward_2
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[4, 4096, 32, 128]' is invalid for input of size 8388608
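The numbers in the RuntimeError already hint at the cause. The quick check below uses only the values printed in the traceback; it suggests that `k_proj` produced 512 features per token, i.e. 4 heads of size 128 rather than the 32 heads the patch assumes. That pattern is consistent with grouped-query attention (GQA), where keys and values use fewer heads than queries. The figure of 4 key/value heads for Yi-6B is inferred from these numbers, not read from the model's config.

```python
# Values copied from the traceback: view(bsz, q_len, self.num_heads, self.head_dim)
bsz, q_len, num_heads, head_dim = 4, 4096, 32, 128
actual_numel = 8388608  # "invalid for input of size 8388608"

# Element count the patch expects if k_proj has num_heads * head_dim outputs
expected_numel = bsz * q_len * num_heads * head_dim
print(expected_numel)  # 67108864

# Width of the k_proj output per token that the actual size implies
k_proj_out = actual_numel // (bsz * q_len)
print(k_proj_out)  # 512

# Number of key/value heads that width corresponds to
num_kv_heads = k_proj_out // head_dim
print(num_kv_heads)  # 4
print(num_heads // num_kv_heads)  # 8 query heads would share each KV head
```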

Does the flash-attention monkey patch only support Llama 2 models and not Yi-6B?
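For comparison, here is a minimal NumPy sketch of GQA-aware key/value reshaping. It assumes the k/v projection output has width `num_kv_heads * head_dim`. It mirrors what the `repeat_kv` helper does in recent Hugging Face transformers versions of `modeling_llama`, but it is not ToolBench's patch, and the function name `project_kv_gqa` is hypothetical. In a torch patch, the analogous change would probably be to view keys/values with `self.num_key_value_heads` (available on `LlamaAttention` in transformers versions that support GQA) and then expand them to `self.num_heads` before the attention call.

```python
import numpy as np


def project_kv_gqa(proj_out, bsz, q_len, num_heads, num_kv_heads, head_dim):
    """Hypothetical helper: reshape a k_proj/v_proj output under GQA.

    proj_out: array of shape (bsz, q_len, num_kv_heads * head_dim).
    Returns an array of shape (bsz, num_heads, q_len, head_dim) in which each
    KV head is repeated num_heads // num_kv_heads times so it lines up with
    the query heads.
    """
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    n_rep = num_heads // num_kv_heads
    # Split the last dim into KV heads (not query heads), then move heads to axis 1
    states = proj_out.reshape(bsz, q_len, num_kv_heads, head_dim).transpose(0, 2, 1, 3)
    # Repeat each KV head n_rep times consecutively along the head axis,
    # matching the expand + reshape order used by repeat_kv in transformers
    return np.repeat(states, n_rep, axis=1)


# Toy sizes: 8 query heads sharing 2 KV heads, head_dim 4
bsz, q_len, num_heads, num_kv_heads, head_dim = 1, 3, 8, 2, 4
proj_out = np.arange(bsz * q_len * num_kv_heads * head_dim).reshape(
    bsz, q_len, num_kv_heads * head_dim
)
out = project_kv_gqa(proj_out, bsz, q_len, num_heads, num_kv_heads, head_dim)
print(out.shape)  # (1, 8, 3, 4)
```

Reshaping `proj_out` with `num_heads` instead (`proj_out.reshape(bsz, q_len, num_heads, head_dim)`) fails for the same reason as the traceback: 24 elements cannot fill a 96-element shape.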

Thanks!
