fix continue training for this version
This commit is contained in:
parent
e64ff9facd
commit
90e628121a
|
|
@ -37,6 +37,7 @@ else:
|
|||
autograd_4bit.switch_backend_to('cuda')
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
import peft
|
||||
import peft.tuners.lora
|
||||
|
|
@ -44,7 +45,7 @@ import peft.tuners.lora
|
|||
import torch
|
||||
import transformers
|
||||
from autograd_4bit import load_llama_model_4bit_low_ram
|
||||
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
|
||||
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel, set_peft_model_state_dict
|
||||
|
||||
# ! Config
|
||||
import train_data
|
||||
|
|
@ -168,6 +169,8 @@ if not ft_config.skip:
|
|||
# Run Trainer
|
||||
if ft_config.resume_checkpoint:
|
||||
print('Resuming from {} ...'.format(ft_config.resume_checkpoint))
|
||||
state_dict_peft = torch.load(os.path.join(ft_config.resume_checkpoint, 'pytorch_model.bin'), map_location='cpu')
|
||||
set_peft_model_state_dict(model, state_dict_peft)
|
||||
trainer.train(ft_config.resume_checkpoint)
|
||||
else:
|
||||
trainer.train()
|
||||
|
|
|
|||
Loading…
Reference in New Issue