Update finetune.py
This commit is contained in:
parent
95cd390d25
commit
c5aa7fb695
11
finetune.py
11
finetune.py
|
|
@ -16,9 +16,13 @@
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
# Early load config to replace attn if needed
|
||||||
|
from arg_parser import get_config
|
||||||
|
ft_config = get_config()
|
||||||
|
|
||||||
replace_llama_attn_with_flash_attn()
|
if ft_config.flash_attention:
|
||||||
|
from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
||||||
|
replace_llama_attn_with_flash_attn()
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
@ -32,10 +36,9 @@ from autograd_4bit import load_llama_model_4bit_low_ram
|
||||||
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
|
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
|
||||||
|
|
||||||
# ! Config
|
# ! Config
|
||||||
from arg_parser import get_config
|
|
||||||
import train_data
|
import train_data
|
||||||
|
|
||||||
ft_config = get_config()
|
|
||||||
|
|
||||||
# * Show loaded parameters
|
# * Show loaded parameters
|
||||||
if ft_config.local_rank == 0:
|
if ft_config.local_rank == 0:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue