optimize mem usage
This commit is contained in:
parent
de3c91834e
commit
eb442494d1
|
|
@ -85,10 +85,27 @@ def make_quant_attn(model):
|
|||
v_proj = m.v_proj
|
||||
|
||||
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
|
||||
del q_proj.qweight
|
||||
del k_proj.qweight
|
||||
del v_proj.qweight
|
||||
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
|
||||
del q_proj.qzeros
|
||||
del k_proj.qzeros
|
||||
del v_proj.qzeros
|
||||
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
|
||||
del q_proj.scales
|
||||
del k_proj.scales
|
||||
del v_proj.scales
|
||||
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
|
||||
del q_proj.g_idx
|
||||
del k_proj.g_idx
|
||||
del v_proj.g_idx
|
||||
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
|
||||
if q_proj.bias is not None:
|
||||
del q_proj.bias
|
||||
del k_proj.bias
|
||||
del v_proj.bias
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
qkv_layer = Autograd4bitQuantLinear(q_proj.in_features,
|
||||
q_proj.out_features + k_proj.out_features + v_proj.out_features,
|
||||
|
|
@ -124,10 +141,22 @@ class QuantLlamaMLP(nn.Module):
|
|||
up_proj = old_module.up_proj
|
||||
|
||||
qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
|
||||
del gate_proj.qweight
|
||||
del up_proj.qweight
|
||||
qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1)
|
||||
del gate_proj.qzeros
|
||||
del up_proj.qzeros
|
||||
scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1)
|
||||
del gate_proj.scales
|
||||
del up_proj.scales
|
||||
g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0)
|
||||
del gate_proj.g_idx
|
||||
del up_proj.g_idx
|
||||
bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
|
||||
if gate_proj.bias is not None:
|
||||
del gate_proj.bias
|
||||
del up_proj.bias
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.gate_up_proj = Autograd4bitQuantLinear(gate_proj.in_features,
|
||||
gate_proj.out_features + up_proj.out_features,
|
||||
|
|
|
|||
Loading…
Reference in New Issue