diff --git a/text-generation-webui/custom_monkey_patch.py b/text-generation-webui/custom_monkey_patch.py
index 6f586e3..0f4d370 100644
--- a/text-generation-webui/custom_monkey_patch.py
+++ b/text-generation-webui/custom_monkey_patch.py
@@ -5,6 +5,8 @@
 from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
 from peft import PeftModel
 from peft.tuners.lora import Linear4bitLt
 
+patch_encode_func = False
+
 def load_model_llama(*args, **kwargs):
     config_path = '../llama-13b-4bit/'
@@ -41,4 +43,15 @@ shared.settings['name2'] = 'Assistant'
 shared.settings['chat_prompt_size_max'] = 2048
 shared.settings['chat_prompt_size'] = 2048
 
+if patch_encode_func:
+    from modules import text_generation
+    text_generation.encode_old = text_generation.encode
+    def encode_patched(*args, **kwargs):
+        input_ids = text_generation.encode_old(*args, **kwargs)
+        if input_ids[0,0] == 0:
+            input_ids = input_ids[:, 1:]
+        return input_ids
+    text_generation.encode = encode_patched
+    print('Encode Function Patched.')
+
 print('Monkey Patch Completed.')