fix continue training for this version
parent e64ff9facd
commit 90e628121a
@@ -37,6 +37,7 @@ else:
     autograd_4bit.switch_backend_to('cuda')

 import sys
+import os

 import peft
 import peft.tuners.lora
@@ -44,7 +45,7 @@ import peft.tuners.lora
 import torch
 import transformers
 from autograd_4bit import load_llama_model_4bit_low_ram
-from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
+from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel, set_peft_model_state_dict

 # ! Config
 import train_data
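The widened peft import pulls in set_peft_model_state_dict, which the next hunk uses to copy the adapter weights saved in a checkpoint back onto the PEFT-wrapped model before training resumes.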
@@ -168,6 +169,8 @@ if not ft_config.skip:
     # Run Trainer
     if ft_config.resume_checkpoint:
         print('Resuming from {} ...'.format(ft_config.resume_checkpoint))
+        state_dict_peft = torch.load(os.path.join(ft_config.resume_checkpoint, 'pytorch_model.bin'), map_location='cpu')
+        set_peft_model_state_dict(model, state_dict_peft)
         trainer.train(ft_config.resume_checkpoint)
     else:
         trainer.train()
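Taken together, the fix makes continue-training work by reloading the LoRA adapter weights from the checkpoint before handing the checkpoint path to the Trainer. A minimal sketch of that flow is below; it assumes `model` is the PEFT-wrapped model, `trainer` is the transformers.Trainer built in finetune.py, and `ft_config.resume_checkpoint` points at a Trainer checkpoint directory containing pytorch_model.bin. The wrapper function name is only for illustration.

import os

import torch
from peft import set_peft_model_state_dict

def resume_or_start(model, trainer, ft_config):
    # Sketch of the resume path after this commit; `model`, `trainer` and
    # `ft_config` are assumed to be set up as in finetune.py.
    if ft_config.resume_checkpoint:
        print('Resuming from {} ...'.format(ft_config.resume_checkpoint))
        # The Trainer checkpoint directory stores the adapter weights in
        # pytorch_model.bin; load them on CPU and apply them to the model.
        state_dict_peft = torch.load(
            os.path.join(ft_config.resume_checkpoint, 'pytorch_model.bin'),
            map_location='cpu')
        set_peft_model_state_dict(model, state_dict_peft)
        # Passing the checkpoint path also lets the Trainer restore its
        # optimizer, scheduler and step counters from the same directory.
        trainer.train(ft_config.resume_checkpoint)
    else:
        trainer.train()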