add quant attn v1 support

This commit is contained in:
John Smith 2023-04-25 12:28:45 +08:00
parent f9c94f27cc
commit 633c28fd25
2 changed files with 118 additions and 64 deletions

View File

@ -40,7 +40,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
assert qweight.shape[0] * 8 == x.shape[-1] assert qweight.shape[0] * 8 == x.shape[-1]
outshape = x.shape[:-1] + (qweight.shape[1],) outshape = x.shape[:-1] + (qweight.shape[1],)
x = x.reshape(-1, x.shape[-1]) x = x.reshape(-1, x.shape[-1])
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device) y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device)
dtype = x.dtype dtype = x.dtype
x = x.half() x = x.half()
quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros) quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros)
@ -114,7 +114,7 @@ def matmul4bit(x, qweight, scales, zeros, g_idx=None):
if use_new: if use_new:
if auto_switch: if auto_switch:
if np.prod(x.shape[:-1]) > auto_switch_thd: if np.prod(x.shape[:-1]) > auto_switch_thd:
output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales.half(), zeros.half()) output = _matmul4bit_v1_recons(x.half(), qweight, scales.half(), zeros.half())
else: else:
output = _matmul4bit_v1(x, qweight, scales, zeros) output = _matmul4bit_v1(x, qweight, scales, zeros)
else: else:

View File

@ -70,7 +70,7 @@ class QuantLlamaAttention(nn.Module):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
def make_quant_attn(model): def make_quant_attn(model, is_v1_model=False):
""" """
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
""" """
@ -84,38 +84,67 @@ def make_quant_attn(model):
k_proj = m.k_proj k_proj = m.k_proj
v_proj = m.v_proj v_proj = m.v_proj
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) if not is_v1_model:
del q_proj.qweight qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
del k_proj.qweight del q_proj.qweight
del v_proj.qweight del k_proj.qweight
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) del v_proj.qweight
del q_proj.qzeros qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
del k_proj.qzeros del q_proj.qzeros
del v_proj.qzeros del k_proj.qzeros
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) del v_proj.qzeros
del q_proj.scales scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
del k_proj.scales del q_proj.scales
del v_proj.scales del k_proj.scales
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) del v_proj.scales
del q_proj.g_idx g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
del k_proj.g_idx del q_proj.g_idx
del v_proj.g_idx del k_proj.g_idx
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None del v_proj.g_idx
if q_proj.bias is not None: bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
del q_proj.bias if q_proj.bias is not None:
del k_proj.bias del q_proj.bias
del v_proj.bias del k_proj.bias
torch.cuda.empty_cache() del v_proj.bias
torch.cuda.empty_cache()
qkv_layer = Autograd4bitQuantLinear(q_proj.in_features, qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features, out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features,
q_proj.groupsize, groupsize=q_proj.groupsize,
is_v1_model=False) is_v1_model=False)
qkv_layer.qweight = qweights qkv_layer.qweight = qweights
qkv_layer.qzeros = qzeros qkv_layer.qzeros = qzeros
qkv_layer.scales = scales qkv_layer.scales = scales
qkv_layer.g_idx = g_idx qkv_layer.g_idx = g_idx
qkv_layer.bias = bias qkv_layer.bias = bias
else:
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
del q_proj.qweight
del k_proj.qweight
del v_proj.qweight
zeros = torch.cat([q_proj.zeros, k_proj.zeros, v_proj.zeros], dim=0)
del q_proj.zeros
del k_proj.zeros
del v_proj.zeros
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
del q_proj.scales
del k_proj.scales
del v_proj.scales
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
if q_proj.bias is not None:
del q_proj.bias
del k_proj.bias
del v_proj.bias
torch.cuda.empty_cache()
qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features,
out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features,
groupsize=-1,
is_v1_model=True)
qkv_layer.qweight = qweights
qkv_layer.zeros = zeros
qkv_layer.scales = scales
qkv_layer.bias = bias
attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb) attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
@ -134,39 +163,64 @@ def make_quant_attn(model):
class QuantLlamaMLP(nn.Module): class QuantLlamaMLP(nn.Module):
def __init__(self, old_module): def __init__(self, old_module, is_v1_model=False):
super().__init__() super().__init__()
gate_proj = old_module.gate_proj gate_proj = old_module.gate_proj
up_proj = old_module.up_proj up_proj = old_module.up_proj
qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1) if not is_v1_model:
del gate_proj.qweight qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
del up_proj.qweight del gate_proj.qweight
qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1) del up_proj.qweight
del gate_proj.qzeros qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1)
del up_proj.qzeros del gate_proj.qzeros
scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1) del up_proj.qzeros
del gate_proj.scales scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1)
del up_proj.scales del gate_proj.scales
g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0) del up_proj.scales
del gate_proj.g_idx g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0)
del up_proj.g_idx del gate_proj.g_idx
bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None del up_proj.g_idx
if gate_proj.bias is not None: bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
del gate_proj.bias if gate_proj.bias is not None:
del up_proj.bias del gate_proj.bias
torch.cuda.empty_cache() del up_proj.bias
torch.cuda.empty_cache()
self.gate_up_proj = Autograd4bitQuantLinear(gate_proj.in_features, self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features,
gate_proj.out_features + up_proj.out_features, out_features=gate_proj.out_features + up_proj.out_features,
gate_proj.groupsize, groupsize=gate_proj.groupsize,
is_v1_model=False) is_v1_model=False)
self.gate_up_proj.qweight = qweights self.gate_up_proj.qweight = qweights
self.gate_up_proj.qzeros = qzeros self.gate_up_proj.qzeros = qzeros
self.gate_up_proj.scales = scales self.gate_up_proj.scales = scales
self.gate_up_proj.g_idx = g_idx self.gate_up_proj.g_idx = g_idx
self.gate_up_proj.bias = bias self.gate_up_proj.bias = bias
else:
qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
del gate_proj.qweight
del up_proj.qweight
zeros = torch.cat([gate_proj.zeros, up_proj.zeros], dim=0)
del gate_proj.zeros
del up_proj.zeros
scales = torch.cat([gate_proj.scales, up_proj.scales], dim=0)
del gate_proj.scales
del up_proj.scales
bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
if gate_proj.bias is not None:
del gate_proj.bias
del up_proj.bias
torch.cuda.empty_cache()
self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features,
out_features=gate_proj.out_features + up_proj.out_features,
groupsize=gate_proj.groupsize,
is_v1_model=True)
self.gate_up_proj.qweight = qweights
self.gate_up_proj.zeros = zeros
self.gate_up_proj.scales = scales
self.gate_up_proj.bias = bias
self.down_proj = old_module.down_proj self.down_proj = old_module.down_proj
self.act_fn = old_module.act_fn self.act_fn = old_module.act_fn
@ -178,15 +232,15 @@ class QuantLlamaMLP(nn.Module):
return self.down_proj(self.act_fn(gate) * up) return self.down_proj(self.act_fn(gate) * up)
def make_fused_mlp(m, parent_name=''): def make_fused_mlp(m, parent_name='', is_v1_model=False):
""" """
Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations. Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
""" """
if isinstance(m, LlamaMLP): if isinstance(m, LlamaMLP):
return QuantLlamaMLP(m) return QuantLlamaMLP(m, is_v1_model=is_v1_model)
for name, child in m.named_children(): for name, child in m.named_children():
child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}") child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}", is_v1_model=is_v1_model)
if isinstance(child, QuantLlamaMLP): if isinstance(child, QuantLlamaMLP):
setattr(m, name, child) setattr(m, name, child)