From 86387a0a3575c82e689a452c20b2c9a5cc94a0f3 Mon Sep 17 00:00:00 2001
From: John Smith
Date: Mon, 3 Apr 2023 23:55:58 +0800
Subject: [PATCH] update multi gpu support in finetune.py

---
 finetune.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/finetune.py b/finetune.py
index f7db502..f374e2b 100644
--- a/finetune.py
+++ b/finetune.py
@@ -59,7 +59,10 @@ lora_config = LoraConfig(
 if ft_config.lora_apply_dir is None:
     model = get_peft_model(model, lora_config)
 else:
-    model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map={'': 0}, torch_dtype=torch.float32) # ! Direct copy from inference.py
+    if ft_config.ddp:
+        model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map="auto", torch_dtype=torch.float32) # ! Direct copy from inference.py
+    else:
+        model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map={'': 0}, torch_dtype=torch.float32)
     print(ft_config.lora_apply_dir, 'loaded')
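
Note (not part of the patch): a minimal sketch of how the patched adapter-loading branch reads in finetune.py. It assumes ft_config, model, and lora_config are defined earlier in the script, and that ft_config.ddp is derived from the launcher's WORLD_SIZE environment variable (a hypothetical derivation, not shown in this diff).

    import os

    import torch
    from peft import PeftModel, get_peft_model

    # Hypothetical: treat the run as DDP when the launcher (e.g. torchrun) sets WORLD_SIZE > 1.
    ft_config.ddp = int(os.environ.get("WORLD_SIZE", 1)) != 1

    if ft_config.lora_apply_dir is None:
        # No existing adapter: wrap the base model with freshly initialized LoRA weights.
        model = get_peft_model(model, lora_config)
    else:
        if ft_config.ddp:
            # Multi-GPU run: let device placement be resolved automatically instead of
            # pinning the loaded adapter to GPU 0.
            model = PeftModel.from_pretrained(
                model, ft_config.lora_apply_dir, device_map="auto", torch_dtype=torch.float32
            )
        else:
            # Single-GPU run: keep the original behaviour and pin everything to GPU 0.
            model = PeftModel.from_pretrained(
                model, ft_config.lora_apply_dir, device_map={"": 0}, torch_dtype=torch.float32
            )
        print(ft_config.lora_apply_dir, 'loaded')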