From c5aa7fb6951ac573089daf7a1aa990bc68270f1f Mon Sep 17 00:00:00 2001
From: yamashi
Date: Fri, 7 Apr 2023 00:43:36 +0200
Subject: [PATCH] Update finetune.py

Load the config before the attention monkey patch so flash attention is
applied only when ft_config.flash_attention is set, instead of
unconditionally on every import.

---
 finetune.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/finetune.py b/finetune.py
index 998b1d6..eea0e66 100644
--- a/finetune.py
+++ b/finetune.py
@@ -16,9 +16,13 @@
 }
 ]
 """
-from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+# Load config early to replace attn if needed
+from arg_parser import get_config
+ft_config = get_config()
 
-replace_llama_attn_with_flash_attn()
+if ft_config.flash_attention:
+    from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+    replace_llama_attn_with_flash_attn()
 
 import sys
 
@@ -32,10 +36,9 @@
 from autograd_4bit import load_llama_model_4bit_low_ram
 from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
 
 # ! Config
-from arg_parser import get_config
 import train_data
 
-ft_config = get_config()
+
 # * Show loaded parameters
 if ft_config.local_rank == 0:
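
Note on the new config dependency: finetune.py now reads
ft_config.flash_attention at import time, so arg_parser.get_config()
must expose that field. arg_parser.py is not part of this diff, so the
snippet below is only a minimal sketch of how such a flag might be
declared with argparse; the flag name, the parse_known_args usage, and
the defaults are assumptions, not the repository's actual code.

    # Hypothetical sketch of arg_parser.py (not part of this patch).
    # Assumes get_config() returns a namespace carrying at least the
    # two fields finetune.py reads above: flash_attention and local_rank.
    import argparse

    def get_config():
        parser = argparse.ArgumentParser(description="finetune.py options (sketch)")
        # Off by default, so the monkey patch stays opt-in.
        parser.add_argument("--flash-attention", dest="flash_attention",
                            action="store_true",
                            help="patch LLaMA attention with flash attention "
                                 "before any model code is imported")
        parser.add_argument("--local-rank", dest="local_rank", type=int, default=0,
                            help="distributed local rank; rank 0 logs the loaded parameters")
        # parse_known_args tolerates flags this sketch does not model
        args, _ = parser.parse_known_args()
        return args

With a parser shaped like this, python finetune.py --flash-attention
turns the patch on, and a plain run leaves stock attention in place.
The ordering in the diff is the design point: get_config() runs before
the monkeypatch import so replace_llama_attn_with_flash_attn() can swap
the attention implementation before any LLaMA model classes are loaded.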