add assert

This commit is contained in:
John Smith 2023-04-21 10:24:58 +08:00
parent 1a0c63edaf
commit 35caccd376
1 changed files with 3 additions and 0 deletions

View File

@ -57,6 +57,9 @@ if ft_config.local_rank == 0:
if ft_config.gradient_checkpointing: if ft_config.gradient_checkpointing:
print('Disable Dropout.') print('Disable Dropout.')
if ft_config.mbatch_size > ft_config.batch_size:
raise Exception('batch_size need to be larger than mbatch_size.')
# Load Basic Model # Load Basic Model
model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir,
ft_config.llama_q4_model, ft_config.llama_q4_model,