From 90e628121a1fc4acc36bb1a766525be05d997b98 Mon Sep 17 00:00:00 2001
From: John Smith
Date: Mon, 17 Apr 2023 14:16:05 +0800
Subject: [PATCH] fix continue training for this version

---
 finetune.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/finetune.py b/finetune.py
index b53be15..142692b 100644
--- a/finetune.py
+++ b/finetune.py
@@ -37,6 +37,7 @@ else:
     autograd_4bit.switch_backend_to('cuda')
 
 import sys
+import os
 
 import peft
 import peft.tuners.lora
@@ -44,7 +45,7 @@ import peft.tuners.lora
 import torch
 import transformers
 from autograd_4bit import load_llama_model_4bit_low_ram
-from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
+from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel, set_peft_model_state_dict
 
 # ! Config
 import train_data
@@ -168,6 +169,8 @@ if not ft_config.skip:
     # Run Trainer
     if ft_config.resume_checkpoint:
         print('Resuming from {} ...'.format(ft_config.resume_checkpoint))
+        state_dict_peft = torch.load(os.path.join(ft_config.resume_checkpoint, 'pytorch_model.bin'), map_location='cpu')
+        set_peft_model_state_dict(model, state_dict_peft)
         trainer.train(ft_config.resume_checkpoint)
     else:
         trainer.train()
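
Note: the patch restores the LoRA adapter weights into the PEFT-wrapped model before handing the checkpoint path to the Trainer, since the Trainer only restores optimizer/scheduler state on resume. Below is a minimal sketch of the same idea outside the patch; it assumes a PEFT-wrapped `model`, a `trainer` built around it, and a `resume_checkpoint` directory containing the `pytorch_model.bin` saved by the previous run (names taken from the patch context, not from any fixed API).

    import os
    import torch
    from peft import set_peft_model_state_dict

    # Load only the adapter weights saved by the previous run; map_location='cpu'
    # avoids requiring the tensors to land on the GPU during deserialization.
    adapter_state = torch.load(
        os.path.join(resume_checkpoint, 'pytorch_model.bin'), map_location='cpu'
    )

    # Copy the adapter weights into the PEFT-wrapped model in place.
    set_peft_model_state_dict(model, adapter_state)

    # The transformers Trainer then picks up optimizer, scheduler, and trainer
    # state from the same checkpoint directory and continues training.
    trainer.train(resume_checkpoint)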