distributed data parallelism with torchrun

kooshi 2023-03-24 23:56:06 -05:00
parent 2bc64597aa
commit 8e471516b8
4 changed files with 20 additions and 7 deletions

.gitignore (new file, +3)

@@ -0,0 +1,3 @@
+alpaca_lora/
+repository/
+__pycache__/


@@ -1,3 +1,4 @@
+import os
 class Finetune4bConfig:
     """Config holder for LLaMA 4bit finetuning
     """
@@ -64,6 +65,12 @@ class Finetune4bConfig:
         self.logging_steps = logging_steps
         self.checkpoint = checkpoint
         self.skip = skip
+        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
+        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        self.ddp = self.world_size != 1
+        self.device_map = "auto" if not self.ddp else {"": self.local_rank}
+        if self.ddp:
+            self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size

     def __str__(self) -> str:
@@ -74,5 +81,6 @@ class Finetune4bConfig:
             f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\
             f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
             f"{self.logging_steps=}\n" +\
-            f"{self.checkpoint=}\n{self.skip=}"
+            f"{self.checkpoint=}\n{self.skip=}\n" +\
+            f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}"
         return s.replace("self.", "")


@@ -197,7 +197,7 @@ def model_to_float(model):
     print('Converted as Float.')

-def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
+def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"):
     import transformers
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
@@ -222,7 +222,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
     model = accelerate.load_checkpoint_and_dispatch(
         model=model,
         checkpoint=model_path,
-        device_map='auto',
+        device_map=device_map,
         no_split_module_classes=["LlamaDecoderLayer"]
     )
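
The loader now forwards device_map to accelerate.load_checkpoint_and_dispatch instead of hard-coding 'auto', so a DDP worker can keep the entire model on its own GPU. A hypothetical call site (the paths are placeholders, not values from this diff):

    import os

    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # each DDP process loads a complete copy of the quantized model onto its own GPU;
    # load_llama_model_4bit_low_ram is the function patched above
    model, tokenizer = load_llama_model_4bit_low_ram(
        "./llama-7b-4bit/",      # placeholder config_path
        "./llama-7b-4bit.pt",    # placeholder model_path
        half=False,
        device_map={"": local_rank},
    )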


@@ -38,13 +38,14 @@ import train_data
 ft_config = get_config()

 # * Show loaded parameters
-print(f"{ft_config}\n")
+if ft_config.local_rank == 0:
+    print(f"{ft_config}\n")

 if ft_config.gradient_checkpointing:
     print('Disable Dropout.')

 # Load Basic Model
-model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model)
+model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model, device_map=ft_config.device_map)

 # Config Lora
 lora_config = LoraConfig(
@@ -93,7 +94,7 @@ if not ft_config.skip:
         apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)

     # Disable Trainer's DataParallel for multigpu
-    if torch.cuda.device_count() > 1:
+    if not ft_config.ddp and torch.cuda.device_count() > 1:
         model.is_parallelizable = True
         model.model_parallel = True
@@ -115,7 +116,8 @@ if not ft_config.skip:
             save_steps=ft_config.save_steps,
             output_dir=ft_config.lora_out_dir,
             save_total_limit=ft_config.save_total_limit,
-            load_best_model_at_end=False
+            load_best_model_at_end=False,
+            ddp_find_unused_parameters=False if ft_config.ddp else None,
         ),
         data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
     )
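
For reference, a minimal sketch of the DDP-relevant Trainer settings; the output directory, batch size, and launch command are illustrative assumptions, not values from this commit:

    # launched as:  torchrun --nproc_per_node=<num_gpus> finetune.py <usual arguments>
    import os
    import transformers

    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1

    training_args = transformers.TrainingArguments(
        output_dir="lora_out",                          # placeholder
        per_device_train_batch_size=1,                  # global batch = this * world_size * accumulation
        gradient_accumulation_steps=max(1, 16 // world_size),
        # every trainable LoRA weight receives a gradient, so the unused-parameter scan can be skipped
        ddp_find_unused_parameters=False if ddp else None,
    )
    # The Hugging Face Trainer picks up LOCAL_RANK from the environment and wraps the model
    # in DistributedDataParallel itself; no manual process-group setup is needed in the script.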