diff --git a/finetune.py b/finetune.py
index 142692b..6cb2278 100644
--- a/finetune.py
+++ b/finetune.py
@@ -57,6 +57,9 @@ if ft_config.local_rank == 0:
 if ft_config.gradient_checkpointing:
     print('Disable Dropout.')
 
+if ft_config.mbatch_size > ft_config.batch_size:
+    raise Exception('batch_size must be larger than or equal to mbatch_size.')
+
 # Load Basic Model
 model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir,
                                                  ft_config.llama_q4_model,