Merge pull request #20 from kooshi/multi-gpu

Enable model parallelism and distributed data parallelism for multi-gpu setups
commit 82dd6dd13e
John Smith authored 2023-03-25 15:06:01 +08:00, committed by GitHub
4 changed files with 32 additions and 7 deletions

.gitignore (new file, 3 additions)

@@ -0,0 +1,3 @@
+alpaca_lora/
+repository/
+__pycache__/

Config module (defines Finetune4bConfig)

@@ -1,3 +1,4 @@
+import os
 class Finetune4bConfig:
     """Config holder for LLaMA 4bit finetuning
     """
@@ -64,6 +65,12 @@ class Finetune4bConfig:
         self.logging_steps = logging_steps
         self.checkpoint = checkpoint
         self.skip = skip
+        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
+        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        self.ddp = self.world_size != 1
+        self.device_map = "auto" if not self.ddp else {"": self.local_rank}
+        if self.ddp:
+            self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
     def __str__(self) -> str:
@@ -74,5 +81,6 @@ class Finetune4bConfig:
             f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\
             f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
             f"{self.logging_steps=}\n" +\
-            f"{self.checkpoint=}\n{self.skip=}"
+            f"{self.checkpoint=}\n{self.skip=}\n" +\
+            f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}"
         return s.replace("self.", "")
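
The new WORLD_SIZE/LOCAL_RANK fields above are what make the config aware of a distributed launch. A minimal standalone sketch of the same logic, assuming the script is started with torchrun (which exports WORLD_SIZE and LOCAL_RANK for every worker); the base accumulation value here is hypothetical:

```python
import os

# torchrun sets these for each worker; plain single-process runs fall back to the defaults.
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))

ddp = world_size != 1
# DDP: pin the whole model to this worker's GPU.
# Otherwise: let accelerate spread layers over all visible GPUs (model parallelism).
device_map = {"": local_rank} if ddp else "auto"

gradient_accumulation_steps = 8  # hypothetical value taken from the user's config
if ddp:
    # Each worker already contributes a batch per step, so divide the accumulation
    # to keep the effective global batch size unchanged.
    gradient_accumulation_steps //= world_size

print(f"world_size={world_size} local_rank={local_rank} ddp={ddp} device_map={device_map}")
```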

Model utilities module (defines load_llama_model_4bit_low_ram)

@@ -15,6 +15,8 @@ auto_switch_thd = 16
 def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
     if shape_of_qweight not in buffer_mat_dic.keys():
         buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
+    elif buffer_mat_dic[shape_of_qweight].device != device:
+        buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
     return buffer_mat_dic[shape_of_qweight]
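
Once layers are dispatched to different GPUs, the per-shape dequantization buffer cached above can be requested from a device other than the one it was created on; the new elif handles that by migrating it. A hedged alternative sketch (not the repository's code) that keys the cache on (shape, device) instead, so a buffer is never moved back and forth between GPUs:

```python
import torch

_buffer_cache = {}

def get_buffer(shape_of_qweight, dtype=torch.float16, device="cuda"):
    # Keying on the device as well means each GPU keeps its own buffer,
    # at the cost of a little extra memory per device.
    key = (tuple(shape_of_qweight), torch.device(device))
    if key not in _buffer_cache:
        _buffer_cache[key] = torch.zeros(
            (shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device
        )
    return _buffer_cache[key]
```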
@@ -195,7 +197,7 @@ def model_to_float(model):
     print('Converted as Float.')
-def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
+def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"):
     import transformers
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
@@ -217,8 +219,13 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
         if name in layers:
             del layers[name]
     make_quant_for_4bit_autograd(model, layers)
-    model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
-    model.cuda()
+    model = accelerate.load_checkpoint_and_dispatch(
+        model=model,
+        checkpoint=model_path,
+        device_map=device_map,
+        no_split_module_classes=["LlamaDecoderLayer"]
+    )
     model.seqlen = 2048
     if half:
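
A hedged usage sketch of the extended loader (paths are hypothetical; the module defining load_llama_model_4bit_low_ram is assumed to be importable). Passing no_split_module_classes keeps each LlamaDecoderLayer whole when accelerate shards the checkpoint across GPUs:

```python
import os
from autograd_4bit import load_llama_model_4bit_low_ram  # assumed module name

# Hypothetical locations of the LLaMA config directory and the 4-bit checkpoint.
config_path = "./llama-13b-4bit/"
model_path = "./llama-13b-4bit.pt"

ddp = int(os.environ.get("WORLD_SIZE", 1)) != 1
local_rank = int(os.environ.get("LOCAL_RANK", 0))

# Per-rank placement under DDP, automatic layer sharding otherwise.
device_map = {"": local_rank} if ddp else "auto"

model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, device_map=device_map)
```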

Finetuning script

@@ -38,13 +38,14 @@ import train_data
 ft_config = get_config()
 # * Show loaded parameters
-print(f"{ft_config}\n")
+if ft_config.local_rank == 0:
+    print(f"{ft_config}\n")
 if ft_config.gradient_checkpointing:
     print('Disable Dropout.')
 # Load Basic Model
-model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model)
+model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model, device_map=ft_config.device_map)
 # Config Lora
 lora_config = LoraConfig(
@@ -92,6 +93,11 @@ if not ft_config.skip:
         from gradient_checkpointing import apply_gradient_checkpointing
         apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)
+    # Disable Trainer's DataParallel for multigpu
+    if not ft_config.ddp and torch.cuda.device_count() > 1:
+        model.is_parallelizable = True
+        model.model_parallel = True
     trainer = transformers.Trainer(
         model=model,
         train_dataset=data.train_data,
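
Setting is_parallelizable and model_parallel tells the Hugging Face Trainer that the model already spans several devices, so it does not add a torch.nn.DataParallel wrapper on top of the accelerate dispatch. A small hedged check (assuming the model variable loaded above) to confirm where the dispatched weights actually landed before training starts:

```python
import itertools
from collections import Counter

# Tally tensor devices: model parallelism should show several cuda:N entries,
# while a DDP worker should show only its own GPU.
placement = Counter(
    str(t.device) for t in itertools.chain(model.parameters(), model.buffers())
)
print(placement)
```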
@@ -110,7 +116,8 @@ if not ft_config.skip:
             save_steps=ft_config.save_steps,
             output_dir=ft_config.lora_out_dir,
             save_total_limit=ft_config.save_total_limit,
-            load_best_model_at_end=False
+            load_best_model_at_end=False,
+            ddp_find_unused_parameters=False if ft_config.ddp else None,
         ),
         data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
     )
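
Finally, a hedged sketch of how the two modes are typically selected at launch (script name and argument values are hypothetical): a single process gets naive model parallelism, while torchrun spawns one worker per GPU and the config switches to DDP, with ddp_find_unused_parameters=False since every trainable LoRA parameter receives a gradient each step.

```python
# Model parallelism (one process, layers sharded across all visible GPUs):
#   python finetune.py ...            (script name assumed)
# Distributed data parallelism (one worker per GPU; torchrun sets WORLD_SIZE and LOCAL_RANK):
#   torchrun --nproc_per_node=2 finetune.py ...

import os
import transformers

ddp = int(os.environ.get("WORLD_SIZE", 1)) != 1

training_args = transformers.TrainingArguments(
    output_dir="./alpaca_lora",        # hypothetical output directory
    per_device_train_batch_size=1,     # hypothetical
    gradient_accumulation_steps=4,     # already divided by world_size under DDP (see config above)
    # Skipping DDP's unused-parameter scan saves a pass over the graph; it is only
    # safe because all trainable parameters get gradients every step.
    ddp_find_unused_parameters=False if ddp else None,
)
```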