Merge pull request #20 from kooshi/multi-gpu
Enable model parallelism and distributed data parallelism for multi-GPU setups
commit 82dd6dd13e
@@ -0,0 +1,3 @@
+alpaca_lora/
+repository/
+__pycache__/
@@ -1,3 +1,4 @@
+import os
 class Finetune4bConfig:
     """Config holder for LLaMA 4bit finetuning
     """
@@ -64,6 +65,12 @@ class Finetune4bConfig:
         self.logging_steps = logging_steps
         self.checkpoint = checkpoint
         self.skip = skip
+        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
+        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        self.ddp = self.world_size != 1
+        self.device_map = "auto" if not self.ddp else {"": self.local_rank}
+        if self.ddp:
+            self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
 
     def __str__(self) -> str:
@@ -74,5 +81,6 @@ class Finetune4bConfig:
             f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\
             f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
             f"{self.logging_steps=}\n" +\
-            f"{self.checkpoint=}\n{self.skip=}"
+            f"{self.checkpoint=}\n{self.skip=}\n" +\
+            f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}"
         return s.replace("self.", "")
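The new Finetune4bConfig fields are driven entirely by the environment: a distributed launcher such as torchrun exports WORLD_SIZE and LOCAL_RANK, while a plain single-process run leaves them unset. A minimal sketch of how the values resolve, using a hypothetical single-GPU gradient_accumulation_steps of 8 to show the batch-size arithmetic:

import os

world_size = int(os.environ.get("WORLD_SIZE", 1))   # 1 when launched without torchrun
local_rank = int(os.environ.get("LOCAL_RANK", 0))
ddp = world_size != 1

# Single process: device_map="auto" lets accelerate spread layers over all visible GPUs.
# torchrun --nproc_per_node=N: each rank pins one full copy of the model to its own GPU.
device_map = "auto" if not ddp else {"": local_rank}

# Hypothetical value; dividing by world_size keeps the effective global batch size
# (micro_batch * grad_accum * world_size) the same as in a single-GPU run.
gradient_accumulation_steps = 8
if ddp:
    gradient_accumulation_steps //= world_size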
@@ -15,6 +15,8 @@ auto_switch_thd = 16
 def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
     if shape_of_qweight not in buffer_mat_dic.keys():
         buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
+    elif buffer_mat_dic[shape_of_qweight].device != device:
+        buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
     return buffer_mat_dic[shape_of_qweight]
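With device_map="auto" the decoder layers can land on different GPUs, so a scratch buffer cached for a given qweight shape may sit on the wrong device for the next layer that requests it; the new elif moves the cached buffer instead of allocating a duplicate per shape and device. A self-contained sketch of the pattern (the torch.device coercion is only there so the comparison also holds when a plain string such as "cuda:1" is passed):

import torch

buffer_mat_dic = {}  # dequantization scratch buffers, keyed by qweight shape

def get_buffer(shape_of_qweight, dtype=torch.float16, device="cuda"):
    if shape_of_qweight not in buffer_mat_dic:
        # qweight packs 8 4-bit values per int32 row, hence the * 8 on the first dimension.
        buffer_mat_dic[shape_of_qweight] = torch.zeros(
            (shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device
        )
    elif buffer_mat_dic[shape_of_qweight].device != torch.device(device):
        # Same shape, different GPU: follow the layer that is asking for the buffer.
        buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
    return buffer_mat_dic[shape_of_qweight]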
@@ -195,7 +197,7 @@ def model_to_float(model):
     print('Converted as Float.')
 
 
-def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
+def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"):
     import transformers
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
@@ -217,8 +219,13 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
         if name in layers:
             del layers[name]
     make_quant_for_4bit_autograd(model, layers)
-    model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
-    model.cuda()
+    model = accelerate.load_checkpoint_and_dispatch(
+        model=model,
+        checkpoint=model_path,
+        device_map=device_map,
+        no_split_module_classes=["LlamaDecoderLayer"]
+    )
+
     model.seqlen = 2048
 
     if half:
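This loader change is what enables both modes: device_map is forwarded to accelerate.load_checkpoint_and_dispatch, no_split_module_classes=["LlamaDecoderLayer"] lets accelerate place whole decoder blocks on different GPUs without ever splitting one block, and the old unconditional model.cuda() is dropped because it would have collapsed that placement back onto a single device. A sketch of the two call shapes (dispatch_4bit_model is a hypothetical wrapper, not a function from the repo):

import accelerate

def dispatch_4bit_model(model, model_path, device_map="auto"):
    # "auto": shard decoder blocks across all visible GPUs (single-process model parallelism).
    # {"": local_rank}: load one full copy onto this rank's GPU (one process per GPU for DDP).
    return accelerate.load_checkpoint_and_dispatch(
        model=model,
        checkpoint=model_path,
        device_map=device_map,
        no_split_module_classes=["LlamaDecoderLayer"],  # keep each decoder block on one device
    )

# model = dispatch_4bit_model(model, model_path, device_map="auto")            # model parallel
# model = dispatch_4bit_model(model, model_path, device_map={"": local_rank})  # DDP, per rank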
finetune.py (13 changed lines)
@@ -38,13 +38,14 @@ import train_data
 ft_config = get_config()
 
 # * Show loaded parameters
-print(f"{ft_config}\n")
+if ft_config.local_rank == 0:
+    print(f"{ft_config}\n")
 
 if ft_config.gradient_checkpointing:
     print('Disable Dropout.')
 
 # Load Basic Model
-model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model)
+model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model, device_map=ft_config.device_map)
 
 # Config Lora
 lora_config = LoraConfig(
@@ -92,6 +93,11 @@ if not ft_config.skip:
         from gradient_checkpointing import apply_gradient_checkpointing
         apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)
 
+    # Disable Trainer's DataParallel for multigpu
+    if not ft_config.ddp and torch.cuda.device_count() > 1:
+        model.is_parallelizable = True
+        model.model_parallel = True
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=data.train_data,
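is_parallelizable and model_parallel are the attributes the Hugging Face Trainer inspects to decide whether a model already manages its own multi-GPU placement; setting them stops the Trainer from wrapping the accelerate-sharded model in torch.nn.DataParallel. A hypothetical predicate capturing the intended effect (not transformers internals):

import torch

def trainer_should_wrap_in_data_parallel(model, ddp_enabled: bool) -> bool:
    if ddp_enabled:
        return False  # torchrun already runs one process per GPU (DistributedDataParallel)
    if getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False):
        return False  # model is already spread across GPUs; DataParallel would fight that layout
    return torch.cuda.device_count() > 1  # naive single-process DataParallel otherwise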
@@ -110,7 +116,8 @@ if not ft_config.skip:
             save_steps=ft_config.save_steps,
             output_dir=ft_config.lora_out_dir,
             save_total_limit=ft_config.save_total_limit,
-            load_best_model_at_end=False
+            load_best_model_at_end=False,
+            ddp_find_unused_parameters=False if ft_config.ddp else None,
         ),
         data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
     )
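ddp_find_unused_parameters maps onto DistributedDataParallel's find_unused_parameters flag. With LoRA every trainable adapter weight participates in each forward pass, so the per-step scan for unused parameters is pure overhead, and leaving it enabled is known to clash with gradient checkpointing. Roughly what ends up being constructed on each rank when ft_config.ddp is true (a sketch, not the Trainer's actual code); launching with `torchrun --nproc_per_node=<num_gpus> finetune.py ...` supplies the WORLD_SIZE and LOCAL_RANK values the new config reads, while a plain `python finetune.py ...` run keeps the single-process model-parallel path:

import torch

def wrap_for_ddp(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    # One full model replica per GPU; gradients are averaged across ranks each step.
    return torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=False,  # ddp_find_unused_parameters=False from TrainingArguments
    )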