diff --git a/finetune.py b/finetune.py index 9874d1b..b53be15 100644 --- a/finetune.py +++ b/finetune.py @@ -172,6 +172,9 @@ if not ft_config.skip: else: trainer.train() + # Restore old model state dict + model.state_dict = old_state_dict + print('Train completed.') # Save Model