optimize mem usage

This commit is contained in:
John Smith 2023-04-22 16:35:18 +08:00
parent de3c91834e
commit eb442494d1
1 changed files with 29 additions and 0 deletions

View File

@ -85,10 +85,27 @@ def make_quant_attn(model):
v_proj = m.v_proj
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
del q_proj.qweight
del k_proj.qweight
del v_proj.qweight
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
del q_proj.qzeros
del k_proj.qzeros
del v_proj.qzeros
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
del q_proj.scales
del k_proj.scales
del v_proj.scales
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
del q_proj.g_idx
del k_proj.g_idx
del v_proj.g_idx
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
if q_proj.bias is not None:
del q_proj.bias
del k_proj.bias
del v_proj.bias
torch.cuda.empty_cache()
qkv_layer = Autograd4bitQuantLinear(q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features,
@ -124,10 +141,22 @@ class QuantLlamaMLP(nn.Module):
up_proj = old_module.up_proj
qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
del gate_proj.qweight
del up_proj.qweight
qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1)
del gate_proj.qzeros
del up_proj.qzeros
scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1)
del gate_proj.scales
del up_proj.scales
g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0)
del gate_proj.g_idx
del up_proj.g_idx
bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
if gate_proj.bias is not None:
del gate_proj.bias
del up_proj.bias
torch.cuda.empty_cache()
self.gate_up_proj = Autograd4bitQuantLinear(gate_proj.in_features,
gate_proj.out_features + up_proj.out_features,