diff --git a/matmul_utils_4bit.py b/matmul_utils_4bit.py index e010bcc..0e08194 100644 --- a/matmul_utils_4bit.py +++ b/matmul_utils_4bit.py @@ -40,7 +40,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros): assert qweight.shape[0] * 8 == x.shape[-1] outshape = x.shape[:-1] + (qweight.shape[1],) x = x.reshape(-1, x.shape[-1]) - y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device) + y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device) dtype = x.dtype x = x.half() quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros) @@ -114,7 +114,7 @@ def matmul4bit(x, qweight, scales, zeros, g_idx=None): if use_new: if auto_switch: if np.prod(x.shape[:-1]) > auto_switch_thd: - output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales.half(), zeros.half()) + output = _matmul4bit_v1_recons(x.half(), qweight, scales.half(), zeros.half()) else: output = _matmul4bit_v1(x, qweight, scales, zeros) else: diff --git a/model_attn_mlp_patch.py b/model_attn_mlp_patch.py index 4320f31..5a0c91b 100644 --- a/model_attn_mlp_patch.py +++ b/model_attn_mlp_patch.py @@ -70,7 +70,7 @@ class QuantLlamaAttention(nn.Module): return attn_output, attn_weights, past_key_value -def make_quant_attn(model): +def make_quant_attn(model, is_v1_model=False): """ Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. """ @@ -84,38 +84,67 @@ def make_quant_attn(model): k_proj = m.k_proj v_proj = m.v_proj - qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) - del q_proj.qweight - del k_proj.qweight - del v_proj.qweight - qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) - del q_proj.qzeros - del k_proj.qzeros - del v_proj.qzeros - scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) - del q_proj.scales - del k_proj.scales - del v_proj.scales - g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) - del q_proj.g_idx - del k_proj.g_idx - del v_proj.g_idx - bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None - if q_proj.bias is not None: - del q_proj.bias - del k_proj.bias - del v_proj.bias - torch.cuda.empty_cache() + if not is_v1_model: + qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) + del q_proj.qweight + del k_proj.qweight + del v_proj.qweight + qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) + del q_proj.qzeros + del k_proj.qzeros + del v_proj.qzeros + scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + del q_proj.scales + del k_proj.scales + del v_proj.scales + g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) + del q_proj.g_idx + del k_proj.g_idx + del v_proj.g_idx + bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None + if q_proj.bias is not None: + del q_proj.bias + del k_proj.bias + del v_proj.bias + torch.cuda.empty_cache() - qkv_layer = Autograd4bitQuantLinear(q_proj.in_features, - q_proj.out_features + k_proj.out_features + v_proj.out_features, - q_proj.groupsize, - is_v1_model=False) - qkv_layer.qweight = qweights - qkv_layer.qzeros = qzeros - qkv_layer.scales = scales - qkv_layer.g_idx = g_idx - qkv_layer.bias = bias + qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features, + out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features, + groupsize=q_proj.groupsize, + is_v1_model=False) + qkv_layer.qweight = qweights + qkv_layer.qzeros = qzeros + qkv_layer.scales = scales + qkv_layer.g_idx = g_idx + qkv_layer.bias = bias + else: + qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) + del q_proj.qweight + del k_proj.qweight + del v_proj.qweight + zeros = torch.cat([q_proj.zeros, k_proj.zeros, v_proj.zeros], dim=0) + del q_proj.zeros + del k_proj.zeros + del v_proj.zeros + scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0) + del q_proj.scales + del k_proj.scales + del v_proj.scales + bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None + if q_proj.bias is not None: + del q_proj.bias + del k_proj.bias + del v_proj.bias + torch.cuda.empty_cache() + + qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features, + out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features, + groupsize=-1, + is_v1_model=True) + qkv_layer.qweight = qweights + qkv_layer.zeros = zeros + qkv_layer.scales = scales + qkv_layer.bias = bias attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb) @@ -134,39 +163,64 @@ def make_quant_attn(model): class QuantLlamaMLP(nn.Module): - def __init__(self, old_module): + def __init__(self, old_module, is_v1_model=False): super().__init__() gate_proj = old_module.gate_proj up_proj = old_module.up_proj - qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1) - del gate_proj.qweight - del up_proj.qweight - qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1) - del gate_proj.qzeros - del up_proj.qzeros - scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1) - del gate_proj.scales - del up_proj.scales - g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0) - del gate_proj.g_idx - del up_proj.g_idx - bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None - if gate_proj.bias is not None: - del gate_proj.bias - del up_proj.bias - torch.cuda.empty_cache() + if not is_v1_model: + qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1) + del gate_proj.qweight + del up_proj.qweight + qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1) + del gate_proj.qzeros + del up_proj.qzeros + scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1) + del gate_proj.scales + del up_proj.scales + g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0) + del gate_proj.g_idx + del up_proj.g_idx + bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None + if gate_proj.bias is not None: + del gate_proj.bias + del up_proj.bias + torch.cuda.empty_cache() - self.gate_up_proj = Autograd4bitQuantLinear(gate_proj.in_features, - gate_proj.out_features + up_proj.out_features, - gate_proj.groupsize, - is_v1_model=False) - self.gate_up_proj.qweight = qweights - self.gate_up_proj.qzeros = qzeros - self.gate_up_proj.scales = scales - self.gate_up_proj.g_idx = g_idx - self.gate_up_proj.bias = bias + self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features, + out_features=gate_proj.out_features + up_proj.out_features, + groupsize=gate_proj.groupsize, + is_v1_model=False) + self.gate_up_proj.qweight = qweights + self.gate_up_proj.qzeros = qzeros + self.gate_up_proj.scales = scales + self.gate_up_proj.g_idx = g_idx + self.gate_up_proj.bias = bias + else: + qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1) + del gate_proj.qweight + del up_proj.qweight + zeros = torch.cat([gate_proj.zeros, up_proj.zeros], dim=0) + del gate_proj.zeros + del up_proj.zeros + scales = torch.cat([gate_proj.scales, up_proj.scales], dim=0) + del gate_proj.scales + del up_proj.scales + bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None + if gate_proj.bias is not None: + del gate_proj.bias + del up_proj.bias + torch.cuda.empty_cache() + + self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features, + out_features=gate_proj.out_features + up_proj.out_features, + groupsize=gate_proj.groupsize, + is_v1_model=True) + self.gate_up_proj.qweight = qweights + self.gate_up_proj.zeros = zeros + self.gate_up_proj.scales = scales + self.gate_up_proj.bias = bias self.down_proj = old_module.down_proj self.act_fn = old_module.act_fn @@ -178,15 +232,15 @@ class QuantLlamaMLP(nn.Module): return self.down_proj(self.act_fn(gate) * up) -def make_fused_mlp(m, parent_name=''): +def make_fused_mlp(m, parent_name='', is_v1_model=False): """ Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations. """ if isinstance(m, LlamaMLP): - return QuantLlamaMLP(m) + return QuantLlamaMLP(m, is_v1_model=is_v1_model) for name, child in m.named_children(): - child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}") + child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}", is_v1_model=is_v1_model) if isinstance(child, QuantLlamaMLP): setattr(m, name, child)