distributed data parallelism with torchrun
This commit is contained in:
parent
2bc64597aa
commit
8e471516b8
|
|
@ -0,0 +1,3 @@
|
||||||
|
alpaca_lora/
|
||||||
|
repository/
|
||||||
|
__pycache__/
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
class Finetune4bConfig:
|
class Finetune4bConfig:
|
||||||
"""Config holder for LLaMA 4bit finetuning
|
"""Config holder for LLaMA 4bit finetuning
|
||||||
"""
|
"""
|
||||||
|
|
@ -64,6 +65,12 @@ class Finetune4bConfig:
|
||||||
self.logging_steps = logging_steps
|
self.logging_steps = logging_steps
|
||||||
self.checkpoint = checkpoint
|
self.checkpoint = checkpoint
|
||||||
self.skip = skip
|
self.skip = skip
|
||||||
|
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
self.ddp = self.world_size != 1
|
||||||
|
self.device_map = "auto" if not self.ddp else {"": self.local_rank}
|
||||||
|
if self.ddp:
|
||||||
|
self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
|
||||||
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
|
@ -74,5 +81,6 @@ class Finetune4bConfig:
|
||||||
f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\
|
f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\
|
||||||
f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
|
f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
|
||||||
f"{self.logging_steps=}\n" +\
|
f"{self.logging_steps=}\n" +\
|
||||||
f"{self.checkpoint=}\n{self.skip=}"
|
f"{self.checkpoint=}\n{self.skip=}\n" +\
|
||||||
|
f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}"
|
||||||
return s.replace("self.", "")
|
return s.replace("self.", "")
|
||||||
|
|
|
||||||
|
|
@ -197,7 +197,7 @@ def model_to_float(model):
|
||||||
print('Converted as Float.')
|
print('Converted as Float.')
|
||||||
|
|
||||||
|
|
||||||
def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
|
def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"):
|
||||||
import transformers
|
import transformers
|
||||||
import accelerate
|
import accelerate
|
||||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||||
|
|
@ -222,7 +222,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
|
||||||
model = accelerate.load_checkpoint_and_dispatch(
|
model = accelerate.load_checkpoint_and_dispatch(
|
||||||
model=model,
|
model=model,
|
||||||
checkpoint=model_path,
|
checkpoint=model_path,
|
||||||
device_map='auto',
|
device_map=device_map,
|
||||||
no_split_module_classes=["LlamaDecoderLayer"]
|
no_split_module_classes=["LlamaDecoderLayer"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
10
finetune.py
10
finetune.py
|
|
@ -38,13 +38,14 @@ import train_data
|
||||||
ft_config = get_config()
|
ft_config = get_config()
|
||||||
|
|
||||||
# * Show loaded parameters
|
# * Show loaded parameters
|
||||||
print(f"{ft_config}\n")
|
if ft_config.local_rank == 0:
|
||||||
|
print(f"{ft_config}\n")
|
||||||
|
|
||||||
if ft_config.gradient_checkpointing:
|
if ft_config.gradient_checkpointing:
|
||||||
print('Disable Dropout.')
|
print('Disable Dropout.')
|
||||||
|
|
||||||
# Load Basic Model
|
# Load Basic Model
|
||||||
model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model)
|
model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model, device_map=ft_config.device_map)
|
||||||
|
|
||||||
# Config Lora
|
# Config Lora
|
||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
|
|
@ -93,7 +94,7 @@ if not ft_config.skip:
|
||||||
apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)
|
apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)
|
||||||
|
|
||||||
# Disable Trainer's DataParallel for multigpu
|
# Disable Trainer's DataParallel for multigpu
|
||||||
if torch.cuda.device_count() > 1:
|
if not ft_config.ddp and torch.cuda.device_count() > 1:
|
||||||
model.is_parallelizable = True
|
model.is_parallelizable = True
|
||||||
model.model_parallel = True
|
model.model_parallel = True
|
||||||
|
|
||||||
|
|
@ -115,7 +116,8 @@ if not ft_config.skip:
|
||||||
save_steps=ft_config.save_steps,
|
save_steps=ft_config.save_steps,
|
||||||
output_dir=ft_config.lora_out_dir,
|
output_dir=ft_config.lora_out_dir,
|
||||||
save_total_limit=ft_config.save_total_limit,
|
save_total_limit=ft_config.save_total_limit,
|
||||||
load_best_model_at_end=False
|
load_best_model_at_end=False,
|
||||||
|
ddp_find_unused_parameters=False if ft_config.ddp else None,
|
||||||
),
|
),
|
||||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue