model parallelism
parent cd1a299ba3 · commit 2bc64597aa
@@ -15,6 +15,8 @@ auto_switch_thd = 16
 def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
     if shape_of_qweight not in buffer_mat_dic.keys():
         buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
+    elif buffer_mat_dic[shape_of_qweight].device != device:
+        buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
     return buffer_mat_dic[shape_of_qweight]


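For reference, a self-contained consolidation of the buffer cache this hunk touches. It is only a sketch: `buffer_mat_dic` and the shape arithmetic mirror the surrounding file, while the usage lines at the bottom are illustrative and assume two CUDA devices. The new `elif` matters because, once the model is spread over several GPUs, a buffer cached on one device can be requested by a layer living on another.

import torch

# Sketch: one reusable zero buffer per qweight shape, created lazily on the
# requesting device and migrated if a later caller lives on a different GPU.
buffer_mat_dic = {}

def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
    if shape_of_qweight not in buffer_mat_dic.keys():
        buffer_mat_dic[shape_of_qweight] = torch.zeros(
            (shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
    elif buffer_mat_dic[shape_of_qweight].device != device:
        # With model parallelism, a layer on cuda:1 must not reuse a buffer that
        # still lives on cuda:0, so move the cached tensor first.
        buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
    return buffer_mat_dic[shape_of_qweight]

# Illustrative use (assumes at least two CUDA devices):
# a = get_buffer((512, 4096), device=torch.device('cuda:0'))
# b = get_buffer((512, 4096), device=torch.device('cuda:1'))  # same cache entry, now on cuda:1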
@@ -217,8 +219,13 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
         if name in layers:
             del layers[name]
     make_quant_for_4bit_autograd(model, layers)
-    model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
-    model.cuda()
+    model = accelerate.load_checkpoint_and_dispatch(
+        model=model,
+        checkpoint=model_path,
+        device_map='auto',
+        no_split_module_classes=["LlamaDecoderLayer"]
+    )
+
     model.seqlen = 2048

     if half:
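The no_split_module_classes=["LlamaDecoderLayer"] argument is what makes the device_map='auto' placement usable for model parallelism: accelerate may put different decoder blocks on different GPUs, but it never cuts a single block in half (dropping the old model.cuda() call keeps that placement instead of pulling everything back onto one device). Below is a rough, standalone sketch of the placement step via accelerate.infer_auto_device_map, which load_checkpoint_and_dispatch runs internally when device_map='auto'; the config sizes and memory budgets are made up for illustration and are not from this repo.

import accelerate
from transformers import LlamaConfig, LlamaForCausalLM  # assumed available in this environment

# Build an empty (meta-device) model so no real memory is allocated yet.
with accelerate.init_empty_weights():
    tiny_config = LlamaConfig(num_hidden_layers=8, hidden_size=1024,
                              intermediate_size=2752, num_attention_heads=16)  # made-up sizes
    model = LlamaForCausalLM(tiny_config)

# Same placement logic that load_checkpoint_and_dispatch applies for device_map='auto':
# fill GPU 0, spill remaining layers to GPU 1, then CPU, without ever splitting
# a LlamaDecoderLayer across two devices.
device_map = accelerate.infer_auto_device_map(
    model,
    max_memory={0: "2GiB", 1: "2GiB", "cpu": "8GiB"},  # illustrative budgets
    no_split_module_classes=["LlamaDecoderLayer"],
)
print(device_map)  # e.g. {'model.embed_tokens': 0, 'model.layers.0': 0, ..., 'lm_head': 1}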
@@ -92,6 +92,11 @@ if not ft_config.skip:
         from gradient_checkpointing import apply_gradient_checkpointing
         apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)

+    # Disable Trainer's DataParallel for multigpu
+    if torch.cuda.device_count() > 1:
+        model.is_parallelizable = True
+        model.model_parallel = True
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=data.train_data,
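Setting model.is_parallelizable and model.model_parallel is how the script tells transformers.Trainer that the model is already spread across GPUs, so it should skip the usual torch.nn.DataParallel wrapping and single-device move. A rough standalone sketch follows, using a small placeholder model rather than the 4-bit LLaMA from this repo; the final is_model_parallel check assumes current Trainer behaviour.

import torch
import transformers

# Placeholder model; any PreTrainedModel works for demonstrating the flags.
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

if torch.cuda.device_count() > 1:
    # Trainer reads these attributes: when both are True it treats the model as
    # model-parallel and does not wrap it in DataParallel or move it to one device.
    model.is_parallelizable = True
    model.model_parallel = True

training_args = transformers.TrainingArguments(output_dir="out", per_device_train_batch_size=1)
trainer = transformers.Trainer(model=model, args=training_args)  # pass a real train_dataset before training
print(trainer.is_model_parallel)  # True on a multi-GPU box once both flags are set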