add patch for encode function to remove eos token at the beginning of left side

This commit is contained in:
John Smith 2023-04-06 12:56:27 +08:00
parent 085d9556f9
commit 9a02a88fb8
1 changed files with 13 additions and 0 deletions

View File

@ -5,6 +5,8 @@ from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
from peft import PeftModel
from peft.tuners.lora import Linear4bitLt
patch_encode_func = False
def load_model_llama(*args, **kwargs):
config_path = '../llama-13b-4bit/'
@ -41,4 +43,15 @@ shared.settings['name2'] = 'Assistant'
shared.settings['chat_prompt_size_max'] = 2048
shared.settings['chat_prompt_size'] = 2048
if patch_encode_func:
from modules import text_generation
text_generation.encode_old = text_generation.encode
def encode_patched(*args, **kwargs):
input_ids = text_generation.encode_old(*args, **kwargs)
if input_ids[0,0] == 0:
input_ids = input_ids[:, 1:]
return input_ids
text_generation.encode = encode_patched
print('Encode Function Patched.')
print('Monkey Patch Completed.')