add quant attn v1 support

John Smith 2023-04-25 12:28:45 +08:00
parent f9c94f27cc
commit 633c28fd25
2 changed files with 118 additions and 64 deletions


@@ -40,7 +40,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
     assert qweight.shape[0] * 8 == x.shape[-1]
     outshape = x.shape[:-1] + (qweight.shape[1],)
     x = x.reshape(-1, x.shape[-1])
-    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
+    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device)
     dtype = x.dtype
     x = x.half()
     quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros)
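For context, here is what _matmul4bit_v1 presumably looks like after this change, assembled from the hunk above. The final reshape back to outshape and the cast back to the saved dtype are assumptions; they fall outside the visible context lines.

def _matmul4bit_v1(x, qweight, scales, zeros):
    # Reconstruction from the hunk above; the last line is assumed.
    assert qweight.shape[0] * 8 == x.shape[-1]
    outshape = x.shape[:-1] + (qweight.shape[1],)
    x = x.reshape(-1, x.shape[-1])
    # the output buffer is now fp16, matching the half-precision inputs fed to the kernel
    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device)
    dtype = x.dtype
    x = x.half()
    quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros)
    # assumed: restore the caller's shape and dtype
    return y.to(dtype).reshape(outshape)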
@@ -114,7 +114,7 @@ def matmul4bit(x, qweight, scales, zeros, g_idx=None):
     if use_new:
         if auto_switch:
             if np.prod(x.shape[:-1]) > auto_switch_thd:
-                output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales.half(), zeros.half())
+                output = _matmul4bit_v1_recons(x.half(), qweight, scales.half(), zeros.half())
             else:
                 output = _matmul4bit_v1(x, qweight, scales, zeros)
     else:
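The changed line sits inside the v1 dispatch of matmul4bit. Read with the hunk above, the logic appears to be: when the flattened batch exceeds auto_switch_thd, take the reconstruction path (the name suggests dequantizing to fp16 and using a dense matmul, which is why the call site can now simply pass x.half()); otherwise call the v1 kernel directly. A commented restatement of that dispatch, with the rationale as an assumption rather than something the commit states:

if np.prod(x.shape[:-1]) > auto_switch_thd:
    # many rows (e.g. prompt processing): dequantize-and-matmul is assumed to win here
    output = _matmul4bit_v1_recons(x.half(), qweight, scales.half(), zeros.half())
else:
    # few rows (e.g. single-token decode): use the fused v1 kernel directly
    output = _matmul4bit_v1(x, qweight, scales, zeros)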


@@ -70,7 +70,7 @@ class QuantLlamaAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


-def make_quant_attn(model):
+def make_quant_attn(model, is_v1_model=False):
     """
     Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
     """
@@ -84,6 +84,7 @@ def make_quant_attn(model):
         k_proj = m.k_proj
         v_proj = m.v_proj
+        if not is_v1_model:
             qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
             del q_proj.qweight
             del k_proj.qweight
@@ -107,15 +108,43 @@ def make_quant_attn(model):
                 del v_proj.bias
             torch.cuda.empty_cache()
-            qkv_layer = Autograd4bitQuantLinear(q_proj.in_features,
-                    q_proj.out_features + k_proj.out_features + v_proj.out_features,
-                    q_proj.groupsize,
+            qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features,
+                    out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features,
+                    groupsize=q_proj.groupsize,
                     is_v1_model=False)
             qkv_layer.qweight = qweights
             qkv_layer.qzeros = qzeros
             qkv_layer.scales = scales
             qkv_layer.g_idx = g_idx
             qkv_layer.bias = bias
+        else:
+            qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+            del q_proj.qweight
+            del k_proj.qweight
+            del v_proj.qweight
+            zeros = torch.cat([q_proj.zeros, k_proj.zeros, v_proj.zeros], dim=0)
+            del q_proj.zeros
+            del k_proj.zeros
+            del v_proj.zeros
+            scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
+            del q_proj.scales
+            del k_proj.scales
+            del v_proj.scales
+            bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+            if q_proj.bias is not None:
+                del q_proj.bias
+                del k_proj.bias
+                del v_proj.bias
+            torch.cuda.empty_cache()
+            qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features,
+                    out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features,
+                    groupsize=-1,
+                    is_v1_model=True)
+            qkv_layer.qweight = qweights
+            qkv_layer.zeros = zeros
+            qkv_layer.scales = scales
+            qkv_layer.bias = bias

         attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
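QuantLlamaAttention itself is untouched by this commit apart from receiving the fused layer, and its forward pass is not shown here. For readers new to the fusion, the sketch below shows the usual pattern such a module follows once q, k and v share one quantized projection; the head_dim argument and the equal three-way split are assumptions, not code from this repository.

def split_fused_qkv(qkv_layer, hidden_states, num_heads, head_dim):
    # Hedged sketch: one quantized matmul, then split into per-head q/k/v.
    bsz, seq_len, _ = hidden_states.shape
    qkv = qkv_layer(hidden_states)
    q, k, v = qkv.chunk(3, dim=-1)  # q/k/v projections have equal width in LLaMA
    q = q.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
    k = k.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
    v = v.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
    return q, k, v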
@@ -134,12 +163,13 @@ def make_quant_attn(model):
 class QuantLlamaMLP(nn.Module):

-    def __init__(self, old_module):
+    def __init__(self, old_module, is_v1_model=False):
         super().__init__()

         gate_proj = old_module.gate_proj
         up_proj = old_module.up_proj
+        if not is_v1_model:
             qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
             del gate_proj.qweight
             del up_proj.qweight
@@ -158,15 +188,39 @@ class QuantLlamaMLP(nn.Module):
                 del up_proj.bias
             torch.cuda.empty_cache()
-            self.gate_up_proj = Autograd4bitQuantLinear(gate_proj.in_features,
-                    gate_proj.out_features + up_proj.out_features,
-                    gate_proj.groupsize,
+            self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features,
+                    out_features=gate_proj.out_features + up_proj.out_features,
+                    groupsize=gate_proj.groupsize,
                     is_v1_model=False)
             self.gate_up_proj.qweight = qweights
             self.gate_up_proj.qzeros = qzeros
             self.gate_up_proj.scales = scales
             self.gate_up_proj.g_idx = g_idx
             self.gate_up_proj.bias = bias
+        else:
+            qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
+            del gate_proj.qweight
+            del up_proj.qweight
+            zeros = torch.cat([gate_proj.zeros, up_proj.zeros], dim=0)
+            del gate_proj.zeros
+            del up_proj.zeros
+            scales = torch.cat([gate_proj.scales, up_proj.scales], dim=0)
+            del gate_proj.scales
+            del up_proj.scales
+            bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
+            if gate_proj.bias is not None:
+                del gate_proj.bias
+                del up_proj.bias
+            torch.cuda.empty_cache()
+            self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features,
+                    out_features=gate_proj.out_features + up_proj.out_features,
+                    groupsize=gate_proj.groupsize,
+                    is_v1_model=True)
+            self.gate_up_proj.qweight = qweights
+            self.gate_up_proj.zeros = zeros
+            self.gate_up_proj.scales = scales
+            self.gate_up_proj.bias = bias

         self.down_proj = old_module.down_proj
         self.act_fn = old_module.act_fn
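The forward pass of QuantLlamaMLP is only partly visible in this commit (its return statement appears as context in the next hunk). The sketch below shows the gate/up split that return line implies; splitting with chunk assumes the gate and up projections have equal width, which holds for LLaMA but is not confirmed by this diff.

def quant_llama_mlp_forward(self, x):
    # Hedged sketch, consistent with `return self.down_proj(self.act_fn(gate) * up)` below.
    gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
    return self.down_proj(self.act_fn(gate) * up)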
@@ -178,15 +232,15 @@ class QuantLlamaMLP(nn.Module):
         return self.down_proj(self.act_fn(gate) * up)


-def make_fused_mlp(m, parent_name=''):
+def make_fused_mlp(m, parent_name='', is_v1_model=False):
     """
     Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
     """
     if isinstance(m, LlamaMLP):
-        return QuantLlamaMLP(m)
+        return QuantLlamaMLP(m, is_v1_model=is_v1_model)

     for name, child in m.named_children():
-        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
+        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}", is_v1_model=is_v1_model)

         if isinstance(child, QuantLlamaMLP):
             setattr(m, name, child)
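A minimal usage sketch of the new flag, assuming model is a LLaMA model whose linear layers have already been replaced with v1-format Autograd4bitQuantLinear modules. The loading step is outside this diff, and whether make_quant_attn patches in place and make_fused_mlp returns the patched model at the top level is not visible here either, so both are assumptions.

# Patch a v1-quantized LLaMA model (sketch; loader not shown).
make_quant_attn(model, is_v1_model=True)           # fuse q/k/v projections
model = make_fused_mlp(model, is_v1_model=True)    # fuse gate/up projections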