add quant attn v1 support
This commit is contained in:
parent
f9c94f27cc
commit
633c28fd25
|
|
@ -40,7 +40,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
|
|||
assert qweight.shape[0] * 8 == x.shape[-1]
|
||||
outshape = x.shape[:-1] + (qweight.shape[1],)
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
|
||||
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device)
|
||||
dtype = x.dtype
|
||||
x = x.half()
|
||||
quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros)
|
||||
|
|
@ -114,7 +114,7 @@ def matmul4bit(x, qweight, scales, zeros, g_idx=None):
|
|||
if use_new:
|
||||
if auto_switch:
|
||||
if np.prod(x.shape[:-1]) > auto_switch_thd:
|
||||
output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales.half(), zeros.half())
|
||||
output = _matmul4bit_v1_recons(x.half(), qweight, scales.half(), zeros.half())
|
||||
else:
|
||||
output = _matmul4bit_v1(x, qweight, scales, zeros)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class QuantLlamaAttention(nn.Module):
|
|||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def make_quant_attn(model):
|
||||
def make_quant_attn(model, is_v1_model=False):
|
||||
"""
|
||||
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
|
||||
"""
|
||||
|
|
@ -84,6 +84,7 @@ def make_quant_attn(model):
|
|||
k_proj = m.k_proj
|
||||
v_proj = m.v_proj
|
||||
|
||||
if not is_v1_model:
|
||||
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
|
||||
del q_proj.qweight
|
||||
del k_proj.qweight
|
||||
|
|
@ -107,15 +108,43 @@ def make_quant_attn(model):
|
|||
del v_proj.bias
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
qkv_layer = Autograd4bitQuantLinear(q_proj.in_features,
|
||||
q_proj.out_features + k_proj.out_features + v_proj.out_features,
|
||||
q_proj.groupsize,
|
||||
qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features,
|
||||
out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features,
|
||||
groupsize=q_proj.groupsize,
|
||||
is_v1_model=False)
|
||||
qkv_layer.qweight = qweights
|
||||
qkv_layer.qzeros = qzeros
|
||||
qkv_layer.scales = scales
|
||||
qkv_layer.g_idx = g_idx
|
||||
qkv_layer.bias = bias
|
||||
else:
|
||||
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
|
||||
del q_proj.qweight
|
||||
del k_proj.qweight
|
||||
del v_proj.qweight
|
||||
zeros = torch.cat([q_proj.zeros, k_proj.zeros, v_proj.zeros], dim=0)
|
||||
del q_proj.zeros
|
||||
del k_proj.zeros
|
||||
del v_proj.zeros
|
||||
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
|
||||
del q_proj.scales
|
||||
del k_proj.scales
|
||||
del v_proj.scales
|
||||
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
|
||||
if q_proj.bias is not None:
|
||||
del q_proj.bias
|
||||
del k_proj.bias
|
||||
del v_proj.bias
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features,
|
||||
out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features,
|
||||
groupsize=-1,
|
||||
is_v1_model=True)
|
||||
qkv_layer.qweight = qweights
|
||||
qkv_layer.zeros = zeros
|
||||
qkv_layer.scales = scales
|
||||
qkv_layer.bias = bias
|
||||
|
||||
attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
|
||||
|
||||
|
|
@ -134,12 +163,13 @@ def make_quant_attn(model):
|
|||
|
||||
|
||||
class QuantLlamaMLP(nn.Module):
|
||||
def __init__(self, old_module):
|
||||
def __init__(self, old_module, is_v1_model=False):
|
||||
super().__init__()
|
||||
|
||||
gate_proj = old_module.gate_proj
|
||||
up_proj = old_module.up_proj
|
||||
|
||||
if not is_v1_model:
|
||||
qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
|
||||
del gate_proj.qweight
|
||||
del up_proj.qweight
|
||||
|
|
@ -158,15 +188,39 @@ class QuantLlamaMLP(nn.Module):
|
|||
del up_proj.bias
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.gate_up_proj = Autograd4bitQuantLinear(gate_proj.in_features,
|
||||
gate_proj.out_features + up_proj.out_features,
|
||||
gate_proj.groupsize,
|
||||
self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features,
|
||||
out_features=gate_proj.out_features + up_proj.out_features,
|
||||
groupsize=gate_proj.groupsize,
|
||||
is_v1_model=False)
|
||||
self.gate_up_proj.qweight = qweights
|
||||
self.gate_up_proj.qzeros = qzeros
|
||||
self.gate_up_proj.scales = scales
|
||||
self.gate_up_proj.g_idx = g_idx
|
||||
self.gate_up_proj.bias = bias
|
||||
else:
|
||||
qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
|
||||
del gate_proj.qweight
|
||||
del up_proj.qweight
|
||||
zeros = torch.cat([gate_proj.zeros, up_proj.zeros], dim=0)
|
||||
del gate_proj.zeros
|
||||
del up_proj.zeros
|
||||
scales = torch.cat([gate_proj.scales, up_proj.scales], dim=0)
|
||||
del gate_proj.scales
|
||||
del up_proj.scales
|
||||
bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
|
||||
if gate_proj.bias is not None:
|
||||
del gate_proj.bias
|
||||
del up_proj.bias
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features,
|
||||
out_features=gate_proj.out_features + up_proj.out_features,
|
||||
groupsize=gate_proj.groupsize,
|
||||
is_v1_model=True)
|
||||
self.gate_up_proj.qweight = qweights
|
||||
self.gate_up_proj.zeros = zeros
|
||||
self.gate_up_proj.scales = scales
|
||||
self.gate_up_proj.bias = bias
|
||||
|
||||
self.down_proj = old_module.down_proj
|
||||
self.act_fn = old_module.act_fn
|
||||
|
|
@ -178,15 +232,15 @@ class QuantLlamaMLP(nn.Module):
|
|||
return self.down_proj(self.act_fn(gate) * up)
|
||||
|
||||
|
||||
def make_fused_mlp(m, parent_name=''):
|
||||
def make_fused_mlp(m, parent_name='', is_v1_model=False):
|
||||
"""
|
||||
Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
|
||||
"""
|
||||
if isinstance(m, LlamaMLP):
|
||||
return QuantLlamaMLP(m)
|
||||
return QuantLlamaMLP(m, is_v1_model=is_v1_model)
|
||||
|
||||
for name, child in m.named_children():
|
||||
child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
|
||||
child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}", is_v1_model=is_v1_model)
|
||||
|
||||
if isinstance(child, QuantLlamaMLP):
|
||||
setattr(m, name, child)
|
||||
|
|
|
|||
Loading…
Reference in New Issue