Merge pull request #20 from kooshi/multi-gpu
Enable model parallelism and distributed data parallelism for multi-gpu setups
This commit is contained in:
commit
82dd6dd13e
|
|
@ -0,0 +1,3 @@
|
|||
alpaca_lora/
|
||||
repository/
|
||||
__pycache__/
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
class Finetune4bConfig:
|
||||
"""Config holder for LLaMA 4bit finetuning
|
||||
"""
|
||||
|
|
@ -64,6 +65,12 @@ class Finetune4bConfig:
|
|||
self.logging_steps = logging_steps
|
||||
self.checkpoint = checkpoint
|
||||
self.skip = skip
|
||||
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
self.ddp = self.world_size != 1
|
||||
self.device_map = "auto" if not self.ddp else {"": self.local_rank}
|
||||
if self.ddp:
|
||||
self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
|
@ -74,5 +81,6 @@ class Finetune4bConfig:
|
|||
f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\
|
||||
f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
|
||||
f"{self.logging_steps=}\n" +\
|
||||
f"{self.checkpoint=}\n{self.skip=}"
|
||||
f"{self.checkpoint=}\n{self.skip=}\n" +\
|
||||
f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}"
|
||||
return s.replace("self.", "")
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ auto_switch_thd = 16
|
|||
def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
|
||||
if shape_of_qweight not in buffer_mat_dic.keys():
|
||||
buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
|
||||
elif buffer_mat_dic[shape_of_qweight].device != device:
|
||||
buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
|
||||
return buffer_mat_dic[shape_of_qweight]
|
||||
|
||||
|
||||
|
|
@ -195,7 +197,7 @@ def model_to_float(model):
|
|||
print('Converted as Float.')
|
||||
|
||||
|
||||
def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
|
||||
def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"):
|
||||
import transformers
|
||||
import accelerate
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||
|
|
@ -217,8 +219,13 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
|
|||
if name in layers:
|
||||
del layers[name]
|
||||
make_quant_for_4bit_autograd(model, layers)
|
||||
model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
|
||||
model.cuda()
|
||||
model = accelerate.load_checkpoint_and_dispatch(
|
||||
model=model,
|
||||
checkpoint=model_path,
|
||||
device_map=device_map,
|
||||
no_split_module_classes=["LlamaDecoderLayer"]
|
||||
)
|
||||
|
||||
model.seqlen = 2048
|
||||
|
||||
if half:
|
||||
|
|
|
|||
13
finetune.py
13
finetune.py
|
|
@ -38,13 +38,14 @@ import train_data
|
|||
ft_config = get_config()
|
||||
|
||||
# * Show loaded parameters
|
||||
print(f"{ft_config}\n")
|
||||
if ft_config.local_rank == 0:
|
||||
print(f"{ft_config}\n")
|
||||
|
||||
if ft_config.gradient_checkpointing:
|
||||
print('Disable Dropout.')
|
||||
|
||||
# Load Basic Model
|
||||
model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model)
|
||||
model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model, device_map=ft_config.device_map)
|
||||
|
||||
# Config Lora
|
||||
lora_config = LoraConfig(
|
||||
|
|
@ -92,6 +93,11 @@ if not ft_config.skip:
|
|||
from gradient_checkpointing import apply_gradient_checkpointing
|
||||
apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)
|
||||
|
||||
# Disable Trainer's DataParallel for multigpu
|
||||
if not ft_config.ddp and torch.cuda.device_count() > 1:
|
||||
model.is_parallelizable = True
|
||||
model.model_parallel = True
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=data.train_data,
|
||||
|
|
@ -110,7 +116,8 @@ if not ft_config.skip:
|
|||
save_steps=ft_config.save_steps,
|
||||
output_dir=ft_config.lora_out_dir,
|
||||
save_total_limit=ft_config.save_total_limit,
|
||||
load_best_model_at_end=False
|
||||
load_best_model_at_end=False,
|
||||
ddp_find_unused_parameters=False if ft_config.ddp else None,
|
||||
),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue