model parallelism

kooshi 2023-03-24 23:03:43 -05:00
parent cd1a299ba3
commit 2bc64597aa
2 changed files with 14 additions and 2 deletions

@@ -15,6 +15,8 @@ auto_switch_thd = 16
 def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
     if shape_of_qweight not in buffer_mat_dic.keys():
         buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
+    elif buffer_mat_dic[shape_of_qweight].device != device:
+        buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
     return buffer_mat_dic[shape_of_qweight]
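
A minimal standalone sketch (not part of the diff; the shapes and devices are hypothetical) of what the new elif branch buys: the cached dequantization buffer follows whichever GPU the requesting layer lives on, instead of staying pinned to the device it was first created on.

import torch

buffer_mat_dic = {}

def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
    # One reusable dequantization buffer per qweight shape
    # (8 packed 4-bit values per int32 row, hence the * 8).
    if shape_of_qweight not in buffer_mat_dic:
        buffer_mat_dic[shape_of_qweight] = torch.zeros(
            (shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
    elif buffer_mat_dic[shape_of_qweight].device != torch.device(device):
        # With layers spread over several GPUs, the cached buffer may sit on the
        # wrong device; move it rather than returning a mismatched tensor.
        buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
    return buffer_mat_dic[shape_of_qweight]

# Hypothetical two-GPU usage: the same shape requested from different devices.
# buf0 = get_buffer((512, 4096), device='cuda:0')
# buf1 = get_buffer((512, 4096), device='cuda:1')  # cached buffer is moved to cuda:1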
@@ -217,8 +219,13 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
         if name in layers:
             del layers[name]
     make_quant_for_4bit_autograd(model, layers)
-    model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
-    model.cuda()
+    model = accelerate.load_checkpoint_and_dispatch(
+        model=model,
+        checkpoint=model_path,
+        device_map='auto',
+        no_split_module_classes=["LlamaDecoderLayer"]
+    )
     model.seqlen = 2048
     if half:
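
For context, a sketch of the accelerate dispatch pattern the new call uses (the paths and model construction below are illustrative, not from this repo): device_map='auto' lets accelerate spread whole layers across the visible GPUs, and no_split_module_classes keeps each LlamaDecoderLayer on a single device. The old model.cuda() call is dropped because it would pull every shard back onto one GPU.

import accelerate
from transformers import LlamaConfig, LlamaForCausalLM

# Hypothetical paths; the real loader builds the 4-bit quantized model before dispatching.
config = LlamaConfig.from_pretrained('path/to/llama-7b-hf')
with accelerate.init_empty_weights():
    model = LlamaForCausalLM(config)  # skeleton only, no weight memory allocated yet

model = accelerate.load_checkpoint_and_dispatch(
    model=model,
    checkpoint='path/to/llama-7b-4bit.pt',
    device_map='auto',                               # balance layers across visible GPUs
    no_split_module_classes=['LlamaDecoderLayer'],   # never split a decoder block across devices
)
# Note: no model.cuda() afterwards; that would move the whole model to one device
# and undo the placement chosen by device_map.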

@@ -92,6 +92,11 @@ if not ft_config.skip:
         from gradient_checkpointing import apply_gradient_checkpointing
         apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)
+    # Disable Trainer's DataParallel for multigpu
+    if torch.cuda.device_count() > 1:
+        model.is_parallelizable = True
+        model.model_parallel = True
     trainer = transformers.Trainer(
         model=model,
         train_dataset=data.train_data,
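
The two flags matter because the transformers Trainer only wraps the model in torch.nn.DataParallel when it does not consider it model-parallel; with the model already sharded by accelerate, replicating it per GPU would be both wasteful and incorrect. A rough sketch of that decision (a simplified, hypothetical condensation, not the Trainer's actual code):

import torch

def trainer_would_use_data_parallel(model) -> bool:
    # Simplified: a model with both flags set is treated as already model-parallel,
    # so the nn.DataParallel wrapper normally added on multi-GPU machines is skipped.
    is_model_parallel = getattr(model, 'is_parallelizable', False) and getattr(model, 'model_parallel', False)
    return torch.cuda.device_count() > 1 and not is_model_parallel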