diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..85c8fed --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +alpaca_lora/ +repository/ +__pycache__/ diff --git a/Finetune4bConfig.py b/Finetune4bConfig.py index c074028..3514060 100644 --- a/Finetune4bConfig.py +++ b/Finetune4bConfig.py @@ -1,3 +1,4 @@ +import os class Finetune4bConfig: """Config holder for LLaMA 4bit finetuning """ @@ -64,6 +65,12 @@ class Finetune4bConfig: self.logging_steps = logging_steps self.checkpoint = checkpoint self.skip = skip + self.world_size = int(os.environ.get("WORLD_SIZE", 1)) + self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + self.ddp = self.world_size != 1 + self.device_map = "auto" if not self.ddp else {"": self.local_rank} + if self.ddp: + self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size def __str__(self) -> str: @@ -74,5 +81,6 @@ class Finetune4bConfig: f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\ f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\ f"{self.logging_steps=}\n" +\ - f"{self.checkpoint=}\n{self.skip=}" + f"{self.checkpoint=}\n{self.skip=}\n" +\ + f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}" return s.replace("self.", "") diff --git a/GPTQ-for-LLaMa/autograd_4bit.py b/GPTQ-for-LLaMa/autograd_4bit.py index c55b7d3..ac90789 100644 --- a/GPTQ-for-LLaMa/autograd_4bit.py +++ b/GPTQ-for-LLaMa/autograd_4bit.py @@ -15,6 +15,8 @@ auto_switch_thd = 16 def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'): if shape_of_qweight not in buffer_mat_dic.keys(): buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device) + elif buffer_mat_dic[shape_of_qweight].device != device: + buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device) return buffer_mat_dic[shape_of_qweight] @@ -195,7 +197,7 @@ def model_to_float(model): print('Converted as Float.') -def load_llama_model_4bit_low_ram(config_path, model_path, half=False): +def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"): import transformers import accelerate from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer @@ -217,8 +219,13 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False): if name in layers: del layers[name] make_quant_for_4bit_autograd(model, layers) - model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto') - model.cuda() + model = accelerate.load_checkpoint_and_dispatch( + model=model, + checkpoint=model_path, + device_map=device_map, + no_split_module_classes=["LlamaDecoderLayer"] + ) + model.seqlen = 2048 if half: diff --git a/finetune.py b/finetune.py index 6224f2e..3f8821e 100644 --- a/finetune.py +++ b/finetune.py @@ -38,13 +38,14 @@ import train_data ft_config = get_config() # * Show loaded parameters -print(f"{ft_config}\n") +if ft_config.local_rank == 0: + print(f"{ft_config}\n") if ft_config.gradient_checkpointing: print('Disable Dropout.') # Load Basic Model -model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model) +model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model, device_map=ft_config.device_map) # Config Lora lora_config = LoraConfig( @@ -92,6 +93,11 @@ if not ft_config.skip: from gradient_checkpointing import apply_gradient_checkpointing apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio) + # Disable Trainer's DataParallel for multigpu + if not ft_config.ddp and torch.cuda.device_count() > 1: + model.is_parallelizable = True + model.model_parallel = True + trainer = transformers.Trainer( model=model, train_dataset=data.train_data, @@ -110,7 +116,8 @@ if not ft_config.skip: save_steps=ft_config.save_steps, output_dir=ft_config.lora_out_dir, save_total_limit=ft_config.save_total_limit, - load_best_model_at_end=False + load_best_model_at_end=False, + ddp_find_unused_parameters=False if ft_config.ddp else None, ), data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False), )