fix continue training for this version

John Smith 2023-04-17 14:16:05 +08:00
parent e64ff9facd
commit 90e628121a
1 changed file with 4 additions and 1 deletion

@@ -37,6 +37,7 @@ else:
     autograd_4bit.switch_backend_to('cuda')
 import sys
+import os
 import peft
 import peft.tuners.lora
@@ -44,7 +45,7 @@ import peft.tuners.lora
 import torch
 import transformers
 from autograd_4bit import load_llama_model_4bit_low_ram
-from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
+from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel, set_peft_model_state_dict
 # ! Config
 import train_data
@@ -168,6 +169,8 @@ if not ft_config.skip:
     # Run Trainer
     if ft_config.resume_checkpoint:
         print('Resuming from {} ...'.format(ft_config.resume_checkpoint))
+        state_dict_peft = torch.load(os.path.join(ft_config.resume_checkpoint, 'pytorch_model.bin'), map_location='cpu')
+        set_peft_model_state_dict(model, state_dict_peft)
        trainer.train(ft_config.resume_checkpoint)
     else:
         trainer.train()
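
In effect, the fix reloads the saved LoRA adapter weights into the model before handing control back to the Trainer, presumably because the Trainer's own checkpoint restore does not reapply the adapter weights in this 4-bit setup. Below is a minimal sketch of the resulting resume path; the helper name resume_adapter and the standalone framing are assumptions for illustration, while ft_config, model, and trainer are the objects from the script above.

import os

import torch
from peft import set_peft_model_state_dict


def resume_adapter(model, checkpoint_dir):
    # Hypothetical helper wrapping the two lines this commit adds.
    # The Trainer checkpoints the PEFT-wrapped model's state dict as
    # pytorch_model.bin inside the checkpoint directory.
    adapter_path = os.path.join(checkpoint_dir, 'pytorch_model.bin')
    # Load on CPU so the load works regardless of where the 4-bit model lives.
    state_dict_peft = torch.load(adapter_path, map_location='cpu')
    # Copy the LoRA adapter weights back into the wrapped model in place.
    set_peft_model_state_dict(model, state_dict_peft)
    return model


# Usage mirroring the patched script:
# if ft_config.resume_checkpoint:
#     resume_adapter(model, ft_config.resume_checkpoint)
#     trainer.train(ft_config.resume_checkpoint)

The first positional argument of trainer.train is resume_from_checkpoint, so the call restores optimizer, scheduler, and step state from the same directory after the adapter weights have been set.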