Add a patch for the encode function that removes the EOS token from the beginning of the left side
This commit is contained in:
parent
085d9556f9
commit
9a02a88fb8
|
|
@ -5,6 +5,8 @@ from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
|
|||
from peft import PeftModel
|
||||
from peft.tuners.lora import Linear4bitLt
|
||||
|
||||
# Feature flag: when True, the block at the bottom of this script monkey-patches
# modules.text_generation.encode to strip a leading token id 0 from encoded
# prompts. Disabled by default.
patch_encode_func = False
|
||||
|
||||
def load_model_llama(*args, **kwargs):
|
||||
|
||||
config_path = '../llama-13b-4bit/'
|
||||
|
|
@ -41,4 +43,15 @@ shared.settings['name2'] = 'Assistant'
|
|||
# Cap the maximum selectable chat prompt size at 2048 tokens.
shared.settings['chat_prompt_size_max'] = 2048
|
||||
# Default the chat prompt size to 2048 tokens as well.
shared.settings['chat_prompt_size'] = 2048
|
||||
|
||||
if patch_encode_func:
    from modules import text_generation

    # Save the original encode exactly once. Without this guard, running the
    # patch a second time would make encode_old point at the already-patched
    # function, and every call would then strip two leading tokens.
    if not hasattr(text_generation, 'encode_old'):
        text_generation.encode_old = text_generation.encode

    def encode_patched(*args, **kwargs):
        """Wrap the original encode and drop a spurious leading token id 0.

        Delegates to ``text_generation.encode_old`` and, when the first
        token of the first sequence is id 0, removes that leading column.
        Assumes the result is a 2-D (batch, seq) tensor-like object that
        supports ``[0, 0]`` indexing and ``[:, 1:]`` slicing — TODO confirm
        against modules.text_generation.encode.
        """
        input_ids = text_generation.encode_old(*args, **kwargs)
        # Check the sequence is non-empty before peeking at the first token,
        # so an empty prompt does not raise an IndexError here.
        if input_ids.shape[-1] > 0 and input_ids[0, 0] == 0:
            input_ids = input_ids[:, 1:]
        return input_ids

    text_generation.encode = encode_patched
    print('Encode Function Patched.')

print('Monkey Patch Completed.')
|
||||
|
|
|
|||
Loading…
Reference in New Issue