Merge pull request #42 from winglian/multigpu-fix

better multi-gpu support, support gpt4all training data

commit 32976f91c4
Author: John Smith
Date:   2023-03-30 00:03:27 +08:00 (committed by GitHub)
4 changed files with 29 additions and 8 deletions


@@ -14,7 +14,8 @@ class Finetune4bConfig:
         gradient_checkpointing_ratio: float,
         warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
         checkpoint: bool, skip: bool, verbose: bool,
-        txt_row_thd: int, use_eos_token: bool, groupsize: int
+        txt_row_thd: int, use_eos_token: bool, groupsize: int,
+        local_rank: int,
     ):
         """
         Args:
@@ -46,6 +47,7 @@ class Finetune4bConfig:
             txt_row_thd (int): Custom row thd for txt file
             use_eos_token (bool): Use Eos token instead of padding with 0
             groupsize (int): Group size of V2 model, use -1 to load V1 model
+            local_rank (int): local rank if using torch.distributed.launch
         """
         self.dataset = dataset
         self.ds_type = ds_type
@@ -76,7 +78,7 @@ class Finetune4bConfig:
         self.txt_row_thd = txt_row_thd
         self.use_eos_token = use_eos_token
         self.world_size = int(os.environ.get("WORLD_SIZE", 1))
-        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        self.local_rank = int(os.environ.get("LOCAL_RANK", local_rank))
         self.ddp = self.world_size != 1
         self.device_map = "auto" if not self.ddp else {"": self.local_rank}
         if self.ddp:

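The hunk above is what makes rank detection launcher-agnostic: `torchrun` exports `LOCAL_RANK` per worker, while the older `torch.distributed.launch` passes `--local_rank` as a CLI flag, so the environment variable wins and the flag is the fallback. A minimal sketch (illustrative only; `resolve_device_map` is a hypothetical helper, not code from this repo) of how those pieces combine into the `device_map`:

```python
import os

def resolve_device_map(local_rank_arg: int):
    """Stand-in for the Finetune4bConfig logic in the hunk above."""
    world_size = int(os.environ.get("WORLD_SIZE", 1))               # set by the DDP launcher
    local_rank = int(os.environ.get("LOCAL_RANK", local_rank_arg))  # env var wins, CLI flag is the fallback
    ddp = world_size != 1
    # Single process: "auto" lets the loader place layers itself.
    # DDP: pin the whole model to this process's GPU so each rank holds a full replica.
    device_map = "auto" if not ddp else {"": local_rank}
    return ddp, local_rank, device_map

# e.g. worker 1 of a 2-GPU job ends up with ddp=True and device_map == {"": 1}
```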

@@ -18,7 +18,7 @@ def parse_commandline():
     parser_training = parser.add_argument_group("training")
     # Config args group
-    parser_config.add_argument("--ds_type", choices=["txt", "alpaca"], default="alpaca", required=False,
+    parser_config.add_argument("--ds_type", choices=["txt", "alpaca", "gpt4all"], default="alpaca", required=False,
         help="Dataset structure format. Default: %(default)s"
     )
     parser_config.add_argument("--lora_out_dir", default="alpaca_lora", required=False,
@@ -64,6 +64,8 @@ def parse_commandline():
     # V2 model support
     parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model, use -1 to load v1 model")
+    parser_training.add_argument("--local_rank", type=int, default=0, help="local rank if using torch.distributed.launch")
     return vars(parser.parse_args())
@@ -97,5 +99,6 @@ def get_config() -> Finetune4bConfig:
         verbose=args["verbose"],
         txt_row_thd=args["txt_row_thd"],
         use_eos_token=args["use_eos_token"]!=0,
-        groupsize=args["groupsize"]
+        groupsize=args["groupsize"],
+        local_rank=args["local_rank"],
     )

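A quick self-contained check of the new CLI surface (illustrative; only the two options this PR touches are reproduced here). With the legacy `python -m torch.distributed.launch` launcher, the `--local_rank` value is injected per worker rather than typed by hand:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--ds_type", choices=["txt", "alpaca", "gpt4all"], default="alpaca",
                    help="Dataset structure format. Default: %(default)s")
parser.add_argument("--local_rank", type=int, default=0,
                    help="local rank if using torch.distributed.launch")

# Simulate what one DDP worker would receive on its command line.
args = parser.parse_args(["--ds_type", "gpt4all", "--local_rank", "1"])
assert args.ds_type == "gpt4all" and args.local_rank == 1
```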

@@ -59,7 +59,7 @@ lora_config = LoraConfig(
 if ft_config.lora_apply_dir is None:
     model = get_peft_model(model, lora_config)
 else:
-    model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map={'': 0}, torch_dtype=torch.float32) # ! Direct copy from inference.py
+    model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map=ft_config.device_map, torch_dtype=torch.float32) # ! Direct copy from inference.py
     print(ft_config.lora_apply_dir, 'loaded')
@@ -83,6 +83,9 @@ if not ft_config.skip:
 elif ft_config.ds_type == "alpaca" and not ft_config.skip:
     #### Stanford Alpaca-like Data
     data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
+elif ft_config.ds_type == "gpt4all" and not ft_config.skip:
+    #### GPT4All Data
+    data = train_data.TrainGPT4All(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
 else:
     raise NotImplementedError("ERROR: Unknown dataset format")
 data.prepare_data(thd=ft_config.txt_row_thd, use_eos_token=ft_config.use_eos_token)

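The `device_map` change in the first hunk is the core of the multi-GPU fix: the old hard-coded `{'': 0}` made every DDP rank load the LoRA adapter onto GPU 0, while `ft_config.device_map` resolves to the rank's own GPU. A small illustration (not repo code, just the resulting mappings for a hypothetical 4-GPU job):

```python
ranks = [0, 1, 2, 3]                 # one DDP process per GPU

old_maps = [{"": 0} for _ in ranks]  # before: every rank maps the adapter to cuda:0
new_maps = [{"": r} for r in ranks]  # after: rank r keeps its replica on cuda:r

assert old_maps[2] == {"": 0} and new_maps[2] == {"": 2}
```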

@@ -184,3 +184,16 @@ class TrainSAD(ATrainData):
     def generate_and_tokenize_prompt(self, data_point, **kwargs):
         prompt = self.generate_prompt(data_point, **kwargs)
         return self.tokenize(prompt, **kwargs)
+# GPT4All-like Data
+class TrainGPT4All(TrainSAD):
+    # Auxiliary methods
+    def generate_prompt(self, data_point, **kwargs):
+        return "{0}\n\n{1}\n{2}\n\n{3}\n{4}\n\n{5}\n{6}".format(
+            "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.",
+            "### Instruction:",
+            data_point["prompt"],
+            "### Input:",
+            "",
+            "### Response:",
+            data_point["response"]
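For reference, a sketch of the string this template produces for a made-up record (the `prompt`/`response` keys match what `TrainGPT4All` reads above; the record itself is invented for illustration):

```python
data_point = {"prompt": "Name three primary colors.",
              "response": "Red, yellow, and blue."}

prompt = "{0}\n\n{1}\n{2}\n\n{3}\n{4}\n\n{5}\n{6}".format(
    "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.",
    "### Instruction:",
    data_point["prompt"],
    "### Input:",
    "",                  # the template leaves the input section empty
    "### Response:",
    data_point["response"],
)
print(prompt)  # Alpaca-style prompt whose "### Input:" section is blank
```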