fix gpt4all training to more closely match the released logic, other small fixes and optimizations

Wing Lian 2023-03-30 19:08:35 -04:00
parent 878eada8dd
commit 8791eaee9a
4 changed files with 124 additions and 50 deletions


@@ -47,7 +47,7 @@ class Autograd4bitQuantLinear(nn.Module):
torch.empty((math.ceil(infeatures/groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
)
self.register_buffer('scales', torch.empty((math.ceil(infeatures/groupsize),outfeatures)))
self.register_buffer('bias', torch.empty(outfeatures))
self.bias = nn.Parameter(torch.empty(outfeatures))
self.register_buffer(
'qweight', torch.empty((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
)
@@ -57,11 +57,11 @@ class Autograd4bitQuantLinear(nn.Module):
if torch.is_grad_enabled():
out = AutogradMatmul4bit.apply(x, self.qweight, self.scales,
self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
out += self.bias
out.add_(self.bias)
else:
out = mm4b.matmul4bit(x, self.qweight, self.scales,
self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
out += self.bias
out.add_(self.bias)
return out
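The bias change above does two things: registering the bias as an nn.Parameter (rather than a buffer) lets the optimizer update it during finetuning, and out.add_(self.bias) applies it in place instead of allocating a second output-sized tensor. A minimal sketch of the pattern, using a plain dense weight as a stand-in for the 4-bit matmul kernels (class and names here are illustrative only):

import torch
import torch.nn as nn

class BiasedLinearSketch(nn.Module):
    # Toy stand-in: only the bias handling mirrors Autograd4bitQuantLinear;
    # the real layer dispatches to the quantized matmul kernels instead.
    def __init__(self, infeatures, outfeatures):
        super().__init__()
        # Buffers are saved/loaded with the state dict but never receive
        # gradients; a Parameter does, so the bias can actually be trained.
        self.register_buffer('weight', torch.randn(infeatures, outfeatures))
        self.bias = nn.Parameter(torch.zeros(outfeatures))

    def forward(self, x):
        out = torch.matmul(x, self.weight)
        out.add_(self.bias)  # in-place: no extra output-sized allocation
        return out

layer = BiasedLinearSketch(8, 4)
out = layer(torch.randn(2, 8))
assert out.requires_grad  # gradients now flow into the bias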


@@ -102,11 +102,7 @@ if not ft_config.skip:
model.is_parallelizable = True
model.model_parallel = True
trainer = transformers.Trainer(
model=model,
train_dataset=data.train_data,
eval_dataset=data.val_data,
args=transformers.TrainingArguments(
training_arguments = transformers.TrainingArguments(
per_device_train_batch_size=ft_config.mbatch_size,
gradient_accumulation_steps=ft_config.gradient_accumulation_steps,
warmup_steps=ft_config.warmup_steps,
@@ -122,7 +118,13 @@ if not ft_config.skip:
save_total_limit=ft_config.save_total_limit,
load_best_model_at_end=False,
ddp_find_unused_parameters=False if ft_config.ddp else None,
),
)
trainer = transformers.Trainer(
model=model,
train_dataset=data.train_data,
eval_dataset=data.val_data,
args=training_arguments,
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False
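Pulling the TrainingArguments out into a named variable before building the Trainer is a readability/reuse refactor; the collator passed alongside it is what turns tokenized examples into causal-LM batches. A small, self-contained sketch of what DataCollatorForLanguageModeling(tokenizer, mlm=False) produces (the gpt2 tokenizer is used purely as a stand-in):

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
features = [tokenizer("short example"), tokenizer("a somewhat longer example sentence")]
batch = collator(features)

# mlm=False means causal LM: inputs are padded to a common length and the
# labels mirror input_ids, with padding positions replaced by -100 so the
# loss ignores them.
print(batch["input_ids"].shape)
print(batch["labels"][0])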


@@ -35,7 +35,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
if debug:
print('_matmul4bit_v1')
assert qweight.shape[0] * 8 == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
outshape = x.shape[:-1] + (qweight.shape[1],)
x = x.reshape(-1, x.shape[-1])
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
dtype = x.dtype
@@ -58,7 +58,7 @@ def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
if debug:
print('_matmul4bit_v2')
assert qweight.shape[0] * 8 == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
outshape = x.shape[:-1] + (qweight.shape[1],)
x = x.reshape(-1, x.shape[-1])
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
dtype = x.dtype
@@ -95,7 +95,7 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False)
quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, groupsize)
if not transpose:
output = torch.matmul(x, buffer)
if transpose:
else:
output = torch.matmul(x, buffer.T)
return output
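Both of these fixes are behaviour-preserving: x.shape is a torch.Size, which is a tuple subclass, so slicing it and concatenating another tuple gives the output shape directly without the list round-trip; and in _matmul4bit_v2_recons the second `if transpose:` becomes `else:`, so the flag is only tested once and the two branches are explicitly exclusive. A quick sketch of the shape idiom:

import torch

x = torch.randn(2, 3, 16)   # arbitrary leading dims, features last
out_features = 32

# torch.Size slices and concatenates like a tuple, so this is equivalent to
# tuple(list(x.shape[:-1]) + [out_features]) without an intermediate list.
outshape = x.shape[:-1] + (out_features,)
assert outshape == (2, 3, out_features)

# The matmul helpers flatten the leading dims, multiply, then restore them.
w = torch.randn(16, out_features)
y = torch.matmul(x.reshape(-1, x.shape[-1]), w).reshape(outshape)
assert y.shape == (2, 3, out_features)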


@@ -1,6 +1,10 @@
import torch
from abc import ABC, abstractmethod
from typing import Dict, Any
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator
import os
@@ -186,16 +190,84 @@ class TrainSAD(ATrainData):
return self.tokenize(prompt, **kwargs)
# GPT4All-like Data
class TrainGPT4All(TrainSAD):
# Auxiliary methods
def generate_prompt(self, data_point, **kwargs):
return "{0}\n\n{1}\n{2}\n\n{3}\n{4}\n\n{5}\n{6}".format(
"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.",
"### Instruction:",
data_point["prompt"],
"### Input:",
"",
"### Response:",
data_point["response"]
)
class TrainGPT4All(ATrainData):
def __init__(self, dataset: str, val_set_size: int, tokenizer, cutoff_len) -> None:
super().__init__(dataset, val_set_size, tokenizer, cutoff_len)
def tokenize(self, prompt: str, use_eos_token=True, **kwargs) -> Dict[str, Any]:
pass
def tokenize_inputs(self, examples):
max_length = self.cutoff_len
input_ids = torch.full((len(examples["prompt"]), max_length), self.tokenizer.pad_token_id)
# ignore bos
newline_tokens = self.tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:]
out = {"labels": [], "attention_mask": []}
for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
input_tokens = self.tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
if input_tokens.dim() == 0:
input_tokens = input_tokens.unsqueeze(0)
input_len = len(input_tokens)
# plus one since we remove bos from response
# but we subtract one since we want to add eos token
remaining_tokens = max_length - input_len - len(newline_tokens) + 1
# remove bos
target_tokens = self.tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
input_ids[i, :input_len] = input_tokens
# add newline between prompt and response
newline_plus_inputs = input_len + len(newline_tokens)
input_ids[i, input_len: newline_plus_inputs] = newline_tokens
# add target tokens, remove bos
input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
# add eos token, enforce stopping if we don't truncate
# we don't want long code to stop generating if truncated during training
if newline_plus_inputs + len(target_tokens) < max_length:
input_ids[i, newline_plus_inputs + len(target_tokens)] = self.tokenizer.eos_token_id
labels = input_ids[i].clone()
labels[: newline_plus_inputs] = -100
labels[labels == self.tokenizer.pad_token_id] = -100
# to debug this, can set all values == -100 to the pad token, then assert that tokenizer.decode(labels, skip_special_tokens=True).strip() == response
attention_mask = input_ids[i].ne(self.tokenizer.pad_token_id).int()
out["labels"].append(labels)
out["attention_mask"].append(attention_mask)
out["input_ids"] = input_ids
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
return out
def prepare_data(self, **kwargs) -> None:
dataset = load_dataset("json", data_files=self.dataset)
self.val_data = None
if self.val_set_size > 0:
dataset = dataset["train"].train_test_split(
test_size=self.val_set_size, shuffle=True, seed=42 # ! Seed = 42 (?)
)
train_dataset, val_dataset = dataset["train"], dataset["test"]
# tokenize inputs and return labels and attention mask
val_dataset = val_dataset.map(
lambda ele: self.tokenize_inputs(ele),
batched=True,
remove_columns=["source", "prompt"],
)
self.val_data = val_dataset.with_format("torch")
else:
train_dataset = dataset["train"]
train_dataset = train_dataset.map(
lambda ele: self.tokenize_inputs(ele),
batched=True,
remove_columns=["source", "prompt"],
)
self.train_data = train_dataset.with_format("torch")
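The rewritten TrainGPT4All follows the released GPT4All packing: prompt tokens, a newline separator, then the response with its leading BOS stripped and an EOS appended when it fits, with labels set to -100 over the prompt, separator and padding so the loss only covers the response. A hedged usage sketch (the import path, dataset file and tokenizer path are placeholders; the tokenizer needs a pad_token_id for the padding and masking to work):

from transformers import LlamaTokenizer
from train_data import TrainGPT4All  # assuming the class lives in the repo's data module

tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-tokenizer")  # placeholder path
tokenizer.pad_token_id = 0  # tokenize_inputs relies on a defined pad token

data = TrainGPT4All("gpt4all_data.json", val_set_size=0, tokenizer=tokenizer, cutoff_len=1024)
data.prepare_data()

sample = data.train_data[0]
labels = sample["labels"]
# Everything before the response (and all padding) is -100, so decoding the
# unmasked positions should recover the response text, as the debug comment
# in tokenize_inputs suggests.
response_only = tokenizer.decode(labels[labels != -100].tolist(), skip_special_tokens=True)
print(response_only)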