distributed data parallelism with torchrun
This commit is contained in:
parent
2bc64597aa
commit
8e471516b8
|
|
@ -0,0 +1,3 @@
|
|||
alpaca_lora/
|
||||
repository/
|
||||
__pycache__/
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
class Finetune4bConfig:
|
||||
"""Config holder for LLaMA 4bit finetuning
|
||||
"""
|
||||
|
|
@ -64,6 +65,12 @@ class Finetune4bConfig:
|
|||
self.logging_steps = logging_steps
|
||||
self.checkpoint = checkpoint
|
||||
self.skip = skip
|
||||
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
self.ddp = self.world_size != 1
|
||||
self.device_map = "auto" if not self.ddp else {"": self.local_rank}
|
||||
if self.ddp:
|
||||
self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
|
@ -74,5 +81,6 @@ class Finetune4bConfig:
|
|||
f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\
|
||||
f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
|
||||
f"{self.logging_steps=}\n" +\
|
||||
f"{self.checkpoint=}\n{self.skip=}"
|
||||
f"{self.checkpoint=}\n{self.skip=}\n" +\
|
||||
f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}"
|
||||
return s.replace("self.", "")
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ def model_to_float(model):
|
|||
print('Converted as Float.')
|
||||
|
||||
|
||||
def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
|
||||
def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"):
|
||||
import transformers
|
||||
import accelerate
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||
|
|
@ -222,7 +222,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
|
|||
model = accelerate.load_checkpoint_and_dispatch(
|
||||
model=model,
|
||||
checkpoint=model_path,
|
||||
device_map='auto',
|
||||
device_map=device_map,
|
||||
no_split_module_classes=["LlamaDecoderLayer"]
|
||||
)
|
||||
|
||||
|
|
|
|||
10
finetune.py
10
finetune.py
|
|
@ -38,13 +38,14 @@ import train_data
|
|||
ft_config = get_config()
|
||||
|
||||
# * Show loaded parameters
|
||||
print(f"{ft_config}\n")
|
||||
if ft_config.local_rank == 0:
|
||||
print(f"{ft_config}\n")
|
||||
|
||||
if ft_config.gradient_checkpointing:
|
||||
print('Disable Dropout.')
|
||||
|
||||
# Load Basic Model
|
||||
model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model)
|
||||
model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model, device_map=ft_config.device_map)
|
||||
|
||||
# Config Lora
|
||||
lora_config = LoraConfig(
|
||||
|
|
@ -93,7 +94,7 @@ if not ft_config.skip:
|
|||
apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)
|
||||
|
||||
# Disable Trainer's DataParallel for multigpu
|
||||
if torch.cuda.device_count() > 1:
|
||||
if not ft_config.ddp and torch.cuda.device_count() > 1:
|
||||
model.is_parallelizable = True
|
||||
model.model_parallel = True
|
||||
|
||||
|
|
@ -115,7 +116,8 @@ if not ft_config.skip:
|
|||
save_steps=ft_config.save_steps,
|
||||
output_dir=ft_config.lora_out_dir,
|
||||
save_total_limit=ft_config.save_total_limit,
|
||||
load_best_model_at_end=False
|
||||
load_best_model_at_end=False,
|
||||
ddp_find_unused_parameters=False if ft_config.ddp else None,
|
||||
),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue