Update autograd_4bit.py
This commit is contained in:
parent
fecce0e1a5
commit
6f4bbb40a9
|
|
@ -115,11 +115,6 @@ def load_llama_model_4bit_low_ram(config_path, model_path):
|
||||||
|
|
||||||
with accelerate.init_empty_weights():
|
with accelerate.init_empty_weights():
|
||||||
config = LLaMAConfig.from_pretrained(config_path)
|
config = LLaMAConfig.from_pretrained(config_path)
|
||||||
def noop(*args, **kwargs):
|
|
||||||
pass
|
|
||||||
torch.nn.init.kaiming_uniform_ = noop
|
|
||||||
torch.nn.init.uniform_ = noop
|
|
||||||
torch.nn.init.normal_ = noop
|
|
||||||
torch.set_default_dtype(torch.half)
|
torch.set_default_dtype(torch.half)
|
||||||
transformers.modeling_utils._init_weights = False
|
transformers.modeling_utils._init_weights = False
|
||||||
torch.set_default_dtype(torch.half)
|
torch.set_default_dtype(torch.half)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue