add quant attn v1 support
commit 633c28fd25
parent f9c94f27cc
@@ -40,7 +40,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
     assert qweight.shape[0] * 8 == x.shape[-1]
     outshape = x.shape[:-1] + (qweight.shape[1],)
     x = x.reshape(-1, x.shape[-1])
-    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
+    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device)
     dtype = x.dtype
     x = x.half()
     quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros)
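For orientation, here is a minimal sketch of how the patched _matmul4bit_v1 plausibly reads end to end. The imports and the epilogue (reshaping to outshape and restoring the caller's dtype) are assumptions inferred from the context lines above; only the fp16 output buffer is changed by this commit.

import torch
import quant_cuda  # the repo's compiled CUDA extension, assumed importable here

def _matmul4bit_v1_sketch(x, qweight, scales, zeros):
    # packed 4-bit v1 weights: 8 values per int32 row of qweight
    assert qweight.shape[0] * 8 == x.shape[-1]
    outshape = x.shape[:-1] + (qweight.shape[1],)
    x = x.reshape(-1, x.shape[-1])
    # fp16 output buffer so it matches the half-precision _faster kernel
    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device)
    dtype = x.dtype
    x = x.half()
    quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros)
    # assumed epilogue: restore the caller's shape and dtype
    return y.to(dtype).reshape(outshape)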
@@ -114,7 +114,7 @@ def matmul4bit(x, qweight, scales, zeros, g_idx=None):
     if use_new:
         if auto_switch:
             if np.prod(x.shape[:-1]) > auto_switch_thd:
-                output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales.half(), zeros.half())
+                output = _matmul4bit_v1_recons(x.half(), qweight, scales.half(), zeros.half())
             else:
                 output = _matmul4bit_v1(x, qweight, scales, zeros)
     else:
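The dispatch above chooses between the raw v1 kernel and the reconstruction path based on how many input rows there are. A small illustration with assumed shapes and an assumed threshold (the real auto_switch_thd default is defined elsewhere in this module):

import numpy as np

auto_switch_thd = 8                 # assumed value, for illustration only
decode_shape = (1, 1, 4096)         # single-token decode step: (batch, seq, hidden)
prefill_shape = (1, 512, 4096)      # long prompt

print(np.prod(decode_shape[:-1]) > auto_switch_thd)    # False -> _matmul4bit_v1
print(np.prod(prefill_shape[:-1]) > auto_switch_thd)   # True  -> _matmul4bit_v1_recons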
@@ -70,7 +70,7 @@ class QuantLlamaAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


-def make_quant_attn(model):
+def make_quant_attn(model, is_v1_model=False):
     """
     Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
     """
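A hedged usage sketch of the new parameter; how the model is loaded is assumed and not part of this diff, only make_quant_attn and its is_v1_model flag are:

# model: a LLaMA model whose Linear layers have already been swapped for
# Autograd4bitQuantLinear modules (loader not shown, assumed to exist).
make_quant_attn(model)  # default: v2 checkpoints (grouped qzeros / g_idx)

# or, for the older v1 checkpoint format this commit adds support for:
make_quant_attn(model, is_v1_model=True)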
@@ -84,38 +84,67 @@ def make_quant_attn(model):
         k_proj = m.k_proj
         v_proj = m.v_proj

-        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
-        del q_proj.qweight
-        del k_proj.qweight
-        del v_proj.qweight
-        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
-        del q_proj.qzeros
-        del k_proj.qzeros
-        del v_proj.qzeros
-        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
-        del q_proj.scales
-        del k_proj.scales
-        del v_proj.scales
-        g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
-        del q_proj.g_idx
-        del k_proj.g_idx
-        del v_proj.g_idx
-        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
-        if q_proj.bias is not None:
-            del q_proj.bias
-            del k_proj.bias
-            del v_proj.bias
-        torch.cuda.empty_cache()
-
-        qkv_layer = Autograd4bitQuantLinear(q_proj.in_features,
-                                            q_proj.out_features + k_proj.out_features + v_proj.out_features,
-                                            q_proj.groupsize,
-                                            is_v1_model=False)
-        qkv_layer.qweight = qweights
-        qkv_layer.qzeros = qzeros
-        qkv_layer.scales = scales
-        qkv_layer.g_idx = g_idx
-        qkv_layer.bias = bias
+        if not is_v1_model:
+            qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+            del q_proj.qweight
+            del k_proj.qweight
+            del v_proj.qweight
+            qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+            del q_proj.qzeros
+            del k_proj.qzeros
+            del v_proj.qzeros
+            scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+            del q_proj.scales
+            del k_proj.scales
+            del v_proj.scales
+            g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
+            del q_proj.g_idx
+            del k_proj.g_idx
+            del v_proj.g_idx
+            bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+            if q_proj.bias is not None:
+                del q_proj.bias
+                del k_proj.bias
+                del v_proj.bias
+            torch.cuda.empty_cache()
+
+            qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features,
+                                                out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features,
+                                                groupsize=q_proj.groupsize,
+                                                is_v1_model=False)
+            qkv_layer.qweight = qweights
+            qkv_layer.qzeros = qzeros
+            qkv_layer.scales = scales
+            qkv_layer.g_idx = g_idx
+            qkv_layer.bias = bias
+        else:
+            qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+            del q_proj.qweight
+            del k_proj.qweight
+            del v_proj.qweight
+            zeros = torch.cat([q_proj.zeros, k_proj.zeros, v_proj.zeros], dim=0)
+            del q_proj.zeros
+            del k_proj.zeros
+            del v_proj.zeros
+            scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
+            del q_proj.scales
+            del k_proj.scales
+            del v_proj.scales
+            bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+            if q_proj.bias is not None:
+                del q_proj.bias
+                del k_proj.bias
+                del v_proj.bias
+            torch.cuda.empty_cache()
+
+            qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features,
+                                                out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features,
+                                                groupsize=-1,
+                                                is_v1_model=True)
+            qkv_layer.qweight = qweights
+            qkv_layer.zeros = zeros
+            qkv_layer.scales = scales
+            qkv_layer.bias = bias

         attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
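The two branches above differ in which quantization buffers exist and along which axis the q/k/v copies are stacked. A rough, runnable illustration with assumed tensor shapes (the exact packing used by Autograd4bitQuantLinear is not spelled out in this diff):

import torch

out_q = out_k = out_v = 4096   # assumed per-projection output width
# v2 (grouped): scales carry a group axis first, so fusing stacks along dim=1
scales_v2 = [torch.ones(32, out) for out in (out_q, out_k, out_v)]   # (n_groups, out_features), assumed
print(torch.cat(scales_v2, dim=1).shape)   # torch.Size([32, 12288])
# v1 (ungrouped): zeros/scales are assumed to be flat per-output-feature vectors,
# which is why the v1 branch stacks along dim=0 and builds the layer with groupsize=-1
scales_v1 = [torch.ones(out) for out in (out_q, out_k, out_v)]
print(torch.cat(scales_v1, dim=0).shape)   # torch.Size([12288])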
@@ -134,39 +163,64 @@ def make_quant_attn(model):


 class QuantLlamaMLP(nn.Module):
-    def __init__(self, old_module):
+    def __init__(self, old_module, is_v1_model=False):
         super().__init__()

         gate_proj = old_module.gate_proj
         up_proj = old_module.up_proj

-        qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
-        del gate_proj.qweight
-        del up_proj.qweight
-        qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1)
-        del gate_proj.qzeros
-        del up_proj.qzeros
-        scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1)
-        del gate_proj.scales
-        del up_proj.scales
-        g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0)
-        del gate_proj.g_idx
-        del up_proj.g_idx
-        bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
-        if gate_proj.bias is not None:
-            del gate_proj.bias
-            del up_proj.bias
-        torch.cuda.empty_cache()
-
-        self.gate_up_proj = Autograd4bitQuantLinear(gate_proj.in_features,
-                                                    gate_proj.out_features + up_proj.out_features,
-                                                    gate_proj.groupsize,
-                                                    is_v1_model=False)
-        self.gate_up_proj.qweight = qweights
-        self.gate_up_proj.qzeros = qzeros
-        self.gate_up_proj.scales = scales
-        self.gate_up_proj.g_idx = g_idx
-        self.gate_up_proj.bias = bias
+        if not is_v1_model:
+            qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
+            del gate_proj.qweight
+            del up_proj.qweight
+            qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1)
+            del gate_proj.qzeros
+            del up_proj.qzeros
+            scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1)
+            del gate_proj.scales
+            del up_proj.scales
+            g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0)
+            del gate_proj.g_idx
+            del up_proj.g_idx
+            bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
+            if gate_proj.bias is not None:
+                del gate_proj.bias
+                del up_proj.bias
+            torch.cuda.empty_cache()
+
+            self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features,
+                                                        out_features=gate_proj.out_features + up_proj.out_features,
+                                                        groupsize=gate_proj.groupsize,
+                                                        is_v1_model=False)
+            self.gate_up_proj.qweight = qweights
+            self.gate_up_proj.qzeros = qzeros
+            self.gate_up_proj.scales = scales
+            self.gate_up_proj.g_idx = g_idx
+            self.gate_up_proj.bias = bias
+        else:
+            qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
+            del gate_proj.qweight
+            del up_proj.qweight
+            zeros = torch.cat([gate_proj.zeros, up_proj.zeros], dim=0)
+            del gate_proj.zeros
+            del up_proj.zeros
+            scales = torch.cat([gate_proj.scales, up_proj.scales], dim=0)
+            del gate_proj.scales
+            del up_proj.scales
+            bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
+            if gate_proj.bias is not None:
+                del gate_proj.bias
+                del up_proj.bias
+            torch.cuda.empty_cache()
+
+            self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features,
+                                                        out_features=gate_proj.out_features + up_proj.out_features,
+                                                        groupsize=gate_proj.groupsize,
+                                                        is_v1_model=True)
+            self.gate_up_proj.qweight = qweights
+            self.gate_up_proj.zeros = zeros
+            self.gate_up_proj.scales = scales
+            self.gate_up_proj.bias = bias

         self.down_proj = old_module.down_proj
         self.act_fn = old_module.act_fn
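For reference, a plausible reading of the forward pass that consumes the fused gate_up_proj; the chunk-based split is an assumption, only the final return line appears in the hunk below:

    def forward(self, x):
        # one fused matmul covers gate_proj and up_proj, then the halves are split
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)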
@@ -178,15 +232,15 @@ class QuantLlamaMLP(nn.Module):
         return self.down_proj(self.act_fn(gate) * up)


-def make_fused_mlp(m, parent_name=''):
+def make_fused_mlp(m, parent_name='', is_v1_model=False):
     """
     Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
     """
     if isinstance(m, LlamaMLP):
-        return QuantLlamaMLP(m)
+        return QuantLlamaMLP(m, is_v1_model=is_v1_model)

     for name, child in m.named_children():
-        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
+        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}", is_v1_model=is_v1_model)

         if isinstance(child, QuantLlamaMLP):
             setattr(m, name, child)
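Taken together, a hedged sketch of how a caller opts into the v1 path after this commit; the surrounding model setup is assumed, and note that make_fused_mlp returns the (possibly replaced) module, so its result is reassigned:

# model: a LLaMA model already carrying Autograd4bitQuantLinear layers loaded
# from a v1-format checkpoint (ungrouped float zeros/scales); loader not shown.
make_quant_attn(model, is_v1_model=True)          # fuse q/k/v in every attention block
model = make_fused_mlp(model, is_v1_model=True)   # fuse gate/up in every MLP block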