diff --git a/finetune.py b/finetune.py
index 998b1d6..eea0e66 100644
--- a/finetune.py
+++ b/finetune.py
@@ -16,9 +16,13 @@
     }
 ]
 """
-from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+# Early load config to replace attn if needed
+from arg_parser import get_config
+ft_config = get_config()
 
-replace_llama_attn_with_flash_attn()
+if ft_config.flash_attention:
+    from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+    replace_llama_attn_with_flash_attn()
 
 
 import sys
@@ -32,10 +36,9 @@
 from autograd_4bit import load_llama_model_4bit_low_ram
 from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
 
 # ! Config
-from arg_parser import get_config
 import train_data
 
-ft_config = get_config()
+
 # * Show loaded parameters
 if ft_config.local_rank == 0: