model parallelism

parent cd1a299ba3
commit 2bc64597aa
@@ -15,6 +15,8 @@ auto_switch_thd = 16
 def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
     if shape_of_qweight not in buffer_mat_dic.keys():
         buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
+    elif buffer_mat_dic[shape_of_qweight].device != device:
+        buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
     return buffer_mat_dic[shape_of_qweight]
 
 
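The new elif branch keeps the buffer cache coherent once the model is spread over several GPUs: a buffer cached for one device is transferred to the requesting device instead of a stale entry being handed to a kernel on the wrong GPU. A minimal sketch of the resulting behavior, assuming at least two visible GPUs and the get_buffer/buffer_mat_dic definitions from the hunk above:

    import torch

    # buffer_mat_dic / get_buffer as defined in the hunk above
    if torch.cuda.device_count() > 1:
        buf = get_buffer((4, 16), device='cuda:0')
        print(buf.device)   # cuda:0, freshly allocated (32, 16) fp16 buffer
        buf = get_buffer((4, 16), device='cuda:1')
        print(buf.device)   # cuda:1: the cache entry was moved via .to(),
                            # not duplicated under a second key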
@@ -217,8 +219,13 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
         if name in layers:
             del layers[name]
     make_quant_for_4bit_autograd(model, layers)
-    model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
-    model.cuda()
+    model = accelerate.load_checkpoint_and_dispatch(
+        model=model,
+        checkpoint=model_path,
+        device_map='auto',
+        no_split_module_classes=["LlamaDecoderLayer"]
+    )
+
     model.seqlen = 2048
 
     if half:
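Dropping the explicit model.cuda() is deliberate: load_checkpoint_and_dispatch already places every weight according to the device map, and no_split_module_classes=["LlamaDecoderLayer"] tells accelerate that each decoder layer must land whole on a single device, so the auto-partitioner only cuts the model between layers. The same placement can be previewed with accelerate's public API; a sketch, assuming a model instance as in the function above:

    from accelerate import infer_auto_device_map

    # Compute the map that device_map='auto' would use; each
    # LlamaDecoderLayer is kept intact on one device.
    device_map = infer_auto_device_map(
        model,
        no_split_module_classes=["LlamaDecoderLayer"],
    )
    print(device_map)   # e.g. {'model.embed_tokens': 0, 'model.layers.0': 0, ...}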
@@ -92,6 +92,11 @@ if not ft_config.skip:
     from gradient_checkpointing import apply_gradient_checkpointing
     apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)
 
+    # Disable Trainer's DataParallel for multigpu
+    if torch.cuda.device_count() > 1:
+        model.is_parallelizable = True
+        model.model_parallel = True
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=data.train_data,
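With more than one GPU visible, transformers' Trainer would otherwise wrap the model in torch.nn.DataParallel, replicating it per device, which conflicts with a model that accelerate has already sharded across GPUs. Setting is_parallelizable and model_parallel makes the Trainer treat the model as model-parallel and skip that wrapping. Roughly the decision involved, paraphrased into a hypothetical helper (the exact Trainer code differs by transformers version):

    def would_wrap_in_data_parallel(model, n_gpu):
        # A model flagged as model-parallel is never wrapped in
        # torch.nn.DataParallel, regardless of the visible GPU count.
        is_model_parallel = getattr(model, "is_parallelizable", False) and \
            getattr(model, "model_parallel", False)
        return n_gpu > 1 and not is_model_parallel

    # With the flags above set: would_wrap_in_data_parallel(model, 2) -> False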