add quant attn v1 support

This commit is contained in:
John Smith 2023-04-25 12:28:45 +08:00
parent f9c94f27cc
commit 633c28fd25
2 changed files with 118 additions and 64 deletions

View File

@ -40,7 +40,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
assert qweight.shape[0] * 8 == x.shape[-1]
outshape = x.shape[:-1] + (qweight.shape[1],)
x = x.reshape(-1, x.shape[-1])
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device)
dtype = x.dtype
x = x.half()
quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros)
@ -114,7 +114,7 @@ def matmul4bit(x, qweight, scales, zeros, g_idx=None):
if use_new:
if auto_switch:
if np.prod(x.shape[:-1]) > auto_switch_thd:
output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales.half(), zeros.half())
output = _matmul4bit_v1_recons(x.half(), qweight, scales.half(), zeros.half())
else:
output = _matmul4bit_v1(x, qweight, scales, zeros)
else:

View File

@ -70,7 +70,7 @@ class QuantLlamaAttention(nn.Module):
return attn_output, attn_weights, past_key_value
def make_quant_attn(model):
def make_quant_attn(model, is_v1_model=False):
"""
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
"""
@ -84,6 +84,7 @@ def make_quant_attn(model):
k_proj = m.k_proj
v_proj = m.v_proj
if not is_v1_model:
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
del q_proj.qweight
del k_proj.qweight
@ -107,15 +108,43 @@ def make_quant_attn(model):
del v_proj.bias
torch.cuda.empty_cache()
qkv_layer = Autograd4bitQuantLinear(q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features,
q_proj.groupsize,
qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features,
out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features,
groupsize=q_proj.groupsize,
is_v1_model=False)
qkv_layer.qweight = qweights
qkv_layer.qzeros = qzeros
qkv_layer.scales = scales
qkv_layer.g_idx = g_idx
qkv_layer.bias = bias
else:
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
del q_proj.qweight
del k_proj.qweight
del v_proj.qweight
zeros = torch.cat([q_proj.zeros, k_proj.zeros, v_proj.zeros], dim=0)
del q_proj.zeros
del k_proj.zeros
del v_proj.zeros
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
del q_proj.scales
del k_proj.scales
del v_proj.scales
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
if q_proj.bias is not None:
del q_proj.bias
del k_proj.bias
del v_proj.bias
torch.cuda.empty_cache()
qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features,
out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features,
groupsize=-1,
is_v1_model=True)
qkv_layer.qweight = qweights
qkv_layer.zeros = zeros
qkv_layer.scales = scales
qkv_layer.bias = bias
attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
@ -134,12 +163,13 @@ def make_quant_attn(model):
class QuantLlamaMLP(nn.Module):
def __init__(self, old_module):
def __init__(self, old_module, is_v1_model=False):
super().__init__()
gate_proj = old_module.gate_proj
up_proj = old_module.up_proj
if not is_v1_model:
qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
del gate_proj.qweight
del up_proj.qweight
@ -158,15 +188,39 @@ class QuantLlamaMLP(nn.Module):
del up_proj.bias
torch.cuda.empty_cache()
self.gate_up_proj = Autograd4bitQuantLinear(gate_proj.in_features,
gate_proj.out_features + up_proj.out_features,
gate_proj.groupsize,
self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features,
out_features=gate_proj.out_features + up_proj.out_features,
groupsize=gate_proj.groupsize,
is_v1_model=False)
self.gate_up_proj.qweight = qweights
self.gate_up_proj.qzeros = qzeros
self.gate_up_proj.scales = scales
self.gate_up_proj.g_idx = g_idx
self.gate_up_proj.bias = bias
else:
qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
del gate_proj.qweight
del up_proj.qweight
zeros = torch.cat([gate_proj.zeros, up_proj.zeros], dim=0)
del gate_proj.zeros
del up_proj.zeros
scales = torch.cat([gate_proj.scales, up_proj.scales], dim=0)
del gate_proj.scales
del up_proj.scales
bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
if gate_proj.bias is not None:
del gate_proj.bias
del up_proj.bias
torch.cuda.empty_cache()
self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features,
out_features=gate_proj.out_features + up_proj.out_features,
groupsize=gate_proj.groupsize,
is_v1_model=True)
self.gate_up_proj.qweight = qweights
self.gate_up_proj.zeros = zeros
self.gate_up_proj.scales = scales
self.gate_up_proj.bias = bias
self.down_proj = old_module.down_proj
self.act_fn = old_module.act_fn
@ -178,15 +232,15 @@ class QuantLlamaMLP(nn.Module):
return self.down_proj(self.act_fn(gate) * up)
def make_fused_mlp(m, parent_name=''):
def make_fused_mlp(m, parent_name='', is_v1_model=False):
"""
Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
"""
if isinstance(m, LlamaMLP):
return QuantLlamaMLP(m)
return QuantLlamaMLP(m, is_v1_model=is_v1_model)
for name, child in m.named_children():
child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}", is_v1_model=is_v1_model)
if isinstance(child, QuantLlamaMLP):
setattr(m, name, child)