fix gpt4all training to more closely match the released logic, other small fixes and optimizations

parent 878eada8dd
commit 8791eaee9a
@@ -30,7 +30,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
 
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
 class Autograd4bitQuantLinear(nn.Module):
 
     def __init__(self, infeatures, outfeatures, groupsize=-1):
         super().__init__()
@@ -47,7 +47,7 @@ class Autograd4bitQuantLinear(nn.Module):
             torch.empty((math.ceil(infeatures/groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
         )
         self.register_buffer('scales', torch.empty((math.ceil(infeatures/groupsize),outfeatures)))
-        self.register_buffer('bias', torch.empty(outfeatures))
+        self.bias = nn.Parameter(torch.empty(outfeatures))
         self.register_buffer(
             'qweight', torch.empty((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
         )
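Note on this change: a tensor added with register_buffer is module state but never appears in module.parameters(), so optimizers and parameter-freezing logic cannot see it, whereas an nn.Parameter does. A minimal sketch of the difference (hypothetical module names, not from this repo); both variants still move with .to()/.half() and are saved in the state dict:

import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('bias', torch.zeros(4))   # state, but not a parameter

class WithParam(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(4))       # visible to optimizers

print(len(list(WithBuffer().parameters())))  # 0
print(len(list(WithParam().parameters())))   # 1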
@@ -57,11 +57,11 @@ class Autograd4bitQuantLinear(nn.Module):
         if torch.is_grad_enabled():
             out = AutogradMatmul4bit.apply(x, self.qweight, self.scales,
                                            self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
-            out += self.bias
+            out.add_(self.bias)
         else:
             out = mm4b.matmul4bit(x, self.qweight, self.scales,
                                   self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
-            out += self.bias
+            out.add_(self.bias)
         return out
 
 
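Note: Tensor.add_ is the explicit in-place form of the broadcasted bias addition; the result is written into out's existing storage rather than into a new tensor. A small sketch with illustrative shapes only:

import torch

out = torch.zeros(2, 3)
bias = torch.ones(3)
ptr = out.data_ptr()
out.add_(bias)                  # broadcasts bias over the last dimension, in place
assert out.data_ptr() == ptr    # same storage, no new allocation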
@@ -115,7 +115,7 @@ def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
 def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
     print("Loading Model ...")
     t0 = time.time()
 
@@ -136,7 +136,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
     )
 
     model.seqlen = seqlen
 
     if half:
         model_to_half(model)
 
@@ -144,9 +144,9 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
     tokenizer.truncation_side = 'left'
 
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
 
     return model, tokenizer
 
 def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
@@ -190,13 +190,13 @@ def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
             m.zeros = m.zeros.half()
             m.scales = m.scales.half()
             m.bias = m.bias.half()
 
     print('Dispatching model ...')
     device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
     model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True, main_device=0)
     torch.cuda.empty_cache()
     print('Total {:.2f} Gib VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024))
 
     # rotary_emb fix
     for n, m in model.named_modules():
         if 'rotary_emb' in n:
@@ -210,7 +210,7 @@ def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
             if n + '.sin_cached' not in hook.weights_map.dataset.state_dict.keys():
                 hook.weights_map.dataset.state_dict[n + '.sin_cached'] = sin_cached.clone().cpu()
                 hook.weights_map.dataset.state_dict[n + '.cos_cached'] = cos_cached.clone().cpu()
 
     tokenizer = LlamaTokenizer.from_pretrained(config_path)
     tokenizer.truncation_side = 'left'
 

finetune.py (36 changed lines)

@@ -102,27 +102,29 @@ if not ft_config.skip:
     model.is_parallelizable = True
     model.model_parallel = True
 
+    training_arguments = transformers.TrainingArguments(
+        per_device_train_batch_size=ft_config.mbatch_size,
+        gradient_accumulation_steps=ft_config.gradient_accumulation_steps,
+        warmup_steps=ft_config.warmup_steps,
+        num_train_epochs=ft_config.epochs,
+        learning_rate=ft_config.lr,
+        fp16=True,
+        logging_steps=ft_config.logging_steps,
+        evaluation_strategy="no",
+        save_strategy="steps",
+        eval_steps=None,
+        save_steps=ft_config.save_steps,
+        output_dir=ft_config.lora_out_dir,
+        save_total_limit=ft_config.save_total_limit,
+        load_best_model_at_end=False,
+        ddp_find_unused_parameters=False if ft_config.ddp else None,
+    )
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=data.train_data,
         eval_dataset=data.val_data,
-        args=transformers.TrainingArguments(
-            per_device_train_batch_size=ft_config.mbatch_size,
-            gradient_accumulation_steps=ft_config.gradient_accumulation_steps,
-            warmup_steps=ft_config.warmup_steps,
-            num_train_epochs=ft_config.epochs,
-            learning_rate=ft_config.lr,
-            fp16=True,
-            logging_steps=ft_config.logging_steps,
-            evaluation_strategy="no",
-            save_strategy="steps",
-            eval_steps=None,
-            save_steps=ft_config.save_steps,
-            output_dir=ft_config.lora_out_dir,
-            save_total_limit=ft_config.save_total_limit,
-            load_best_model_at_end=False,
-            ddp_find_unused_parameters=False if ft_config.ddp else None,
-        ),
+        args=training_arguments,
         data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
     )
     model.config.use_cache = False
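The hunk above only hoists the TrainingArguments construction out of the Trainer call; the values passed are unchanged. A minimal sketch of the resulting pattern (placeholder values, not the ft_config fields):

import transformers

training_arguments = transformers.TrainingArguments(
    output_dir="output",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=3e-4,
    logging_steps=10,
    save_strategy="steps",
    save_steps=50,
)
# The arguments object can now be inspected or tweaked before the Trainer is built:
training_arguments.warmup_steps = 20
# trainer = transformers.Trainer(model=model, args=training_arguments, ...)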
@@ -27,15 +27,15 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
     input x: (n, m)
     qweight: (j, k)
     where m == j*8
 
     perform x @ qweight
 
     return y:
     """
     if debug:
         print('_matmul4bit_v1')
     assert qweight.shape[0] * 8 == x.shape[-1]
-    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
+    outshape = x.shape[:-1] + (qweight.shape[1],)
     x = x.reshape(-1, x.shape[-1])
     y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
     dtype = x.dtype
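The outshape change (repeated in _matmul4bit_v2 below) swaps the list round-trip for direct tuple concatenation; torch.Size slices like a tuple, so the result is equivalent and skips the intermediate list. A quick sketch:

import torch

x = torch.randn(2, 5, 32)        # (..., infeatures)
qweight = torch.empty(4, 64)     # packed int weights: (infeatures // 8, outfeatures)
outshape = x.shape[:-1] + (qweight.shape[1],)
assert tuple(outshape) == tuple(list(x.shape[:-1]) + [qweight.shape[1]])
y = torch.zeros(x.shape[0] * x.shape[1], qweight.shape[-1]).reshape(outshape)  # usable directly in reshape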
@@ -50,15 +50,15 @@ def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
     input x: (n, m)
     qweight: (j, k)
     where m == j*8
 
     perform x @ qweight
 
     return y:
     """
    if debug:
         print('_matmul4bit_v2')
     assert qweight.shape[0] * 8 == x.shape[-1]
-    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
+    outshape = x.shape[:-1] + (qweight.shape[1],)
     x = x.reshape(-1, x.shape[-1])
     y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
     dtype = x.dtype
@@ -95,7 +95,7 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False):
     quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, groupsize)
     if not transpose:
         output = torch.matmul(x, buffer)
-    if transpose:
+    else:
         output = torch.matmul(x, buffer.T)
     return output
 
 
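Replacing the second condition test with else does not change behavior (the two branches were already mutually exclusive); it only removes the redundant check. The transpose path lets the same reconstructed fp16 buffer serve both multiplication directions; a minimal sketch of that pattern (assumption: this mirrors the usual x @ W versus grad @ W.T usage, the diff itself does not show the caller):

import torch

x = torch.randn(2, 8)
W = torch.randn(8, 4)                      # stand-in for the reconstructed weight buffer
grad_out = torch.randn(2, 4)

out = torch.matmul(x, W)                   # transpose=False path
grad_in = torch.matmul(grad_out, W.T)      # transpose=True path
assert out.shape == (2, 4) and grad_in.shape == (2, 8)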
@@ -1,6 +1,10 @@
+import torch
+
 from abc import ABC, abstractmethod
 from typing import Dict, Any
 from datasets import load_dataset, Dataset
+from torch.utils.data import DataLoader
+from transformers import DefaultDataCollator
 import os
 
 
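The new imports suggest the GPT4All path feeds batches that are already padded and label-masked: transformers.DefaultDataCollator simply stacks whatever fields are present, unlike DataCollatorForLanguageModeling, which builds labels itself. A small sketch (pairing it with this dataset class is an assumption; the diff only adds the import):

import torch
from transformers import DefaultDataCollator

collator = DefaultDataCollator()
features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [-100, 2, 3]},
    {"input_ids": [4, 5, 0], "attention_mask": [1, 1, 0], "labels": [5, -100, -100]},
]
batch = collator(features)
print({k: tuple(v.shape) for k, v in batch.items()})   # every field becomes a (2, 3) tensor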
@@ -126,7 +130,7 @@ class TrainTxt(ATrainData):
 class TrainSAD(ATrainData):
     def __init__(self, dataset: str, val_set_size: int, tokenizer, cutoff_len) -> None:
         super().__init__(dataset, val_set_size, tokenizer, cutoff_len)
 
     def tokenize(self, prompt: str, use_eos_token=True, **kwargs) -> Dict[str, Any]:
         # there's probably a way to do this with the tokenizer settings
         # but again, gotta move fast
@@ -186,16 +190,84 @@ class TrainSAD(ATrainData):
         return self.tokenize(prompt, **kwargs)
 
 # GPT4All-like Data
-class TrainGPT4All(TrainSAD):
-    # Auxiliary methods
-    def generate_prompt(self, data_point, **kwargs):
-        return "{0}\n\n{1}\n{2}\n\n{3}\n{4}\n\n{5}\n{6}".format(
-            "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.",
-            "### Instruction:",
-            data_point["prompt"],
-            "### Input:",
-            "",
-            "### Response:",
-            data_point["response"]
-        )
+class TrainGPT4All(ATrainData):
+    def __init__(self, dataset: str, val_set_size: int, tokenizer, cutoff_len) -> None:
+        super().__init__(dataset, val_set_size, tokenizer, cutoff_len)
 
+    def tokenize(self, prompt: str, use_eos_token=True, **kwargs) -> Dict[str, Any]:
+        pass
+
+    def tokenize_inputs(self, examples):
+        max_length = self.cutoff_len
+        input_ids = torch.full((len(examples["prompt"]), max_length), self.tokenizer.pad_token_id)
+        # ignore bos
+        newline_tokens = self.tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:]
+
+        out = {"labels": [], "attention_mask": []}
+        for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
+            input_tokens = self.tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
+            if input_tokens.dim() == 0:
+                input_tokens = input_tokens.unsqueeze(0)
+
+            input_len = len(input_tokens)
+
+            # plus one since we remove bos from response
+            # but we subtract one since we want to add eos token
+            remaining_tokens = max_length - input_len - len(newline_tokens) + 1
+            # remove bos
+            target_tokens = self.tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
+
+            input_ids[i, :input_len] = input_tokens
+            # add newline between prompt and response
+            newline_plus_inputs = input_len + len(newline_tokens)
+            input_ids[i, input_len: newline_plus_inputs] = newline_tokens
+
+            # add target tokens, remove bos
+            input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
+            # add eos token, enforce stopping if we don't truncate
+            # we don't want long code to stop generating if truncated during training
+            if newline_plus_inputs + len(target_tokens) < max_length:
+                input_ids[i, newline_plus_inputs + len(target_tokens)] = self.tokenizer.eos_token_id
+
+            labels = input_ids[i].clone()
+            labels[: newline_plus_inputs] = -100
+            labels[labels == self.tokenizer.pad_token_id] = -100
+            # to debug this, can set all values == -100 to the pad token, then assert that tokenizer.decode(labels, skip_special_tokens=True).strip() == response
+
+            attention_mask = input_ids[i].ne(self.tokenizer.pad_token_id).int()
+
+            out["labels"].append(labels)
+            out["attention_mask"].append(attention_mask)
+
+        out["input_ids"] = input_ids
+
+        out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
+
+        return out
+
+    def prepare_data(self, **kwargs) -> None:
+        dataset = load_dataset("json", data_files=self.dataset)
+
+        self.val_data = None
+        if self.val_set_size > 0:
+            dataset = dataset["train"].train_test_split(
+                test_size=self.val_set_size, shuffle=True, seed=42  # ! Seed = 42 (?)
+            )
+            train_dataset, val_dataset = dataset["train"], dataset["test"]
+
+            # tokenize inputs and return labels and attention mask
+            val_dataset = val_dataset.map(
+                lambda ele: self.tokenize_inputs(ele),
+                batched=True,
+                remove_columns=["source", "prompt"],
+            )
+            self.val_data = val_dataset.with_format("torch")
+        else:
+            train_dataset = dataset["train"]
+
+        train_dataset = train_dataset.map(
+            lambda ele: self.tokenize_inputs(ele),
+            batched=True,
+            remove_columns=["source", "prompt"],
+        )
+        self.train_data = train_dataset.with_format("torch")
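The -100 values written into labels above rely on the standard Hugging Face / PyTorch convention: positions labelled -100 are skipped by the cross-entropy loss, so the prompt, the separating newline, and padding never contribute to training. A minimal sketch of that convention:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 6, 32000)                       # (batch, seq, vocab)
labels = torch.tensor([[-100, -100, 5, 9, 2, -100]])    # prompt/padding masked out
loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    labels.view(-1),
    ignore_index=-100,                                  # the default ignore index
)
print(loss)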