update multi gpu support in finetune.py

This commit is contained in:
John Smith 2023-04-03 23:55:58 +08:00
parent 5655f218ed
commit 86387a0a35
1 changed files with 4 additions and 1 deletions

View File

@ -59,7 +59,10 @@ lora_config = LoraConfig(
if ft_config.lora_apply_dir is None:
model = get_peft_model(model, lora_config)
else:
model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map={'': 0}, torch_dtype=torch.float32) # ! Direct copy from inference.py
if ft_config.ddp:
model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map="auto", torch_dtype=torch.float32) # ! Direct copy from inference.py
else:
model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map={'': 0}, torch_dtype=torch.float32)
print(ft_config.lora_apply_dir, 'loaded')