Merge pull request #20 from kooshi/multi-gpu

Enable model parallelism and distributed data parallelism for multi-gpu setups
This commit is contained in:
John Smith 2023-03-25 15:06:01 +08:00 committed by GitHub
commit 82dd6dd13e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 7 deletions

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
alpaca_lora/
repository/
__pycache__/

View File

@ -1,3 +1,4 @@
import os
class Finetune4bConfig: class Finetune4bConfig:
"""Config holder for LLaMA 4bit finetuning """Config holder for LLaMA 4bit finetuning
""" """
@ -64,6 +65,12 @@ class Finetune4bConfig:
self.logging_steps = logging_steps self.logging_steps = logging_steps
self.checkpoint = checkpoint self.checkpoint = checkpoint
self.skip = skip self.skip = skip
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
self.ddp = self.world_size != 1
self.device_map = "auto" if not self.ddp else {"": self.local_rank}
if self.ddp:
self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
def __str__(self) -> str: def __str__(self) -> str:
@ -74,5 +81,6 @@ class Finetune4bConfig:
f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\ f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\
f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\ f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
f"{self.logging_steps=}\n" +\ f"{self.logging_steps=}\n" +\
f"{self.checkpoint=}\n{self.skip=}" f"{self.checkpoint=}\n{self.skip=}\n" +\
f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}"
return s.replace("self.", "") return s.replace("self.", "")

View File

@ -15,6 +15,8 @@ auto_switch_thd = 16
def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'): def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
if shape_of_qweight not in buffer_mat_dic.keys(): if shape_of_qweight not in buffer_mat_dic.keys():
buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device) buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
elif buffer_mat_dic[shape_of_qweight].device != device:
buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
return buffer_mat_dic[shape_of_qweight] return buffer_mat_dic[shape_of_qweight]
@ -195,7 +197,7 @@ def model_to_float(model):
print('Converted as Float.') print('Converted as Float.')
def load_llama_model_4bit_low_ram(config_path, model_path, half=False): def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"):
import transformers import transformers
import accelerate import accelerate
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
@ -217,8 +219,13 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
if name in layers: if name in layers:
del layers[name] del layers[name]
make_quant_for_4bit_autograd(model, layers) make_quant_for_4bit_autograd(model, layers)
model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto') model = accelerate.load_checkpoint_and_dispatch(
model.cuda() model=model,
checkpoint=model_path,
device_map=device_map,
no_split_module_classes=["LlamaDecoderLayer"]
)
model.seqlen = 2048 model.seqlen = 2048
if half: if half:

View File

@ -38,13 +38,14 @@ import train_data
ft_config = get_config() ft_config = get_config()
# * Show loaded parameters # * Show loaded parameters
print(f"{ft_config}\n") if ft_config.local_rank == 0:
print(f"{ft_config}\n")
if ft_config.gradient_checkpointing: if ft_config.gradient_checkpointing:
print('Disable Dropout.') print('Disable Dropout.')
# Load Basic Model # Load Basic Model
model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model) model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model, device_map=ft_config.device_map)
# Config Lora # Config Lora
lora_config = LoraConfig( lora_config = LoraConfig(
@ -92,6 +93,11 @@ if not ft_config.skip:
from gradient_checkpointing import apply_gradient_checkpointing from gradient_checkpointing import apply_gradient_checkpointing
apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio) apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)
# Disable Trainer's DataParallel for multigpu
if not ft_config.ddp and torch.cuda.device_count() > 1:
model.is_parallelizable = True
model.model_parallel = True
trainer = transformers.Trainer( trainer = transformers.Trainer(
model=model, model=model,
train_dataset=data.train_data, train_dataset=data.train_data,
@ -110,7 +116,8 @@ if not ft_config.skip:
save_steps=ft_config.save_steps, save_steps=ft_config.save_steps,
output_dir=ft_config.lora_out_dir, output_dir=ft_config.lora_out_dir,
save_total_limit=ft_config.save_total_limit, save_total_limit=ft_config.save_total_limit,
load_best_model_at_end=False load_best_model_at_end=False,
ddp_find_unused_parameters=False if ft_config.ddp else None,
), ),
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False), data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
) )