Add patch for encode function to remove EOS token at the beginning of the left side
parent 085d9556f9
commit 9a02a88fb8
@@ -5,6 +5,8 @@ from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
 from peft import PeftModel
 from peft.tuners.lora import Linear4bitLt
 
+patch_encode_func = False
+
 def load_model_llama(*args, **kwargs):
 
     config_path = '../llama-13b-4bit/'
@@ -41,4 +43,15 @@ shared.settings['name2'] = 'Assistant'
 shared.settings['chat_prompt_size_max'] = 2048
 shared.settings['chat_prompt_size'] = 2048
 
+if patch_encode_func:
+    from modules import text_generation
+    text_generation.encode_old = text_generation.encode
+    def encode_patched(*args, **kwargs):
+        input_ids = text_generation.encode_old(*args, **kwargs)
+        if input_ids[0,0] == 0:
+            input_ids = input_ids[:, 1:]
+        return input_ids
+    text_generation.encode = encode_patched
+    print('Encode Function Patched.')
+
 print('Monkey Patch Completed.')
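For reference, a minimal standalone sketch of the wrap-and-replace technique this commit uses: keep a handle to the original encode function, install a wrapper, and have the wrapper strip a leading token id 0 before returning. The `text_generation` stand-in built here with `types.SimpleNamespace` and its toy `encode` are assumptions for illustration only; the real patch wraps `modules.text_generation.encode` from the web UI.

import types
import torch

# Stand-in for the real modules.text_generation module (assumption for this
# sketch): its encode() returns a (1, n) tensor of ids and always prepends id 0.
text_generation = types.SimpleNamespace(
    encode=lambda prompt, max_length=2048: torch.tensor(
        [[0] + [ord(c) % 100 + 3 for c in prompt][:max_length]]
    )
)

# Same pattern as the commit: save the original, then replace it with a wrapper.
text_generation.encode_old = text_generation.encode

def encode_patched(*args, **kwargs):
    input_ids = text_generation.encode_old(*args, **kwargs)
    if input_ids[0, 0] == 0:          # a leading token id 0 is present
        input_ids = input_ids[:, 1:]  # drop it from the left side
    return input_ids

text_generation.encode = encode_patched

print(text_generation.encode("hello")[0, :4])  # output no longer starts with 0

Note that in the actual commit the wrapper is only installed when `patch_encode_func` is set to True near the top of the script; it defaults to False.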