From 42ef3484a966e0e1e7487858a590b820bc02762b Mon Sep 17 00:00:00 2001
From: John Smith
Date: Wed, 26 Apr 2023 14:38:57 +0800
Subject: [PATCH] fix _SentinelTokenStoppingCriteria

---
 server/__init__.py           |  2 +-
 server/server.py             | 22 ++++++++++++
 .../generate_monkey_patch.py | 35 ++-----------------
 3 files changed, 26 insertions(+), 33 deletions(-)

diff --git a/server/__init__.py b/server/__init__.py
index d01cea1..c934e8e 100644
--- a/server/__init__.py
+++ b/server/__init__.py
@@ -1 +1 @@
-from .server import ModelClient, ModelServer
+from .server import ModelClient, ModelServer, _SentinelTokenStoppingCriteria
diff --git a/server/server.py b/server/server.py
index de40fc5..7fbe833 100644
--- a/server/server.py
+++ b/server/server.py
@@ -24,6 +24,28 @@ def clear_torch_cache():
         torch.cuda.empty_cache()
 
 
+# Copied from https://github.com/PygmalionAI/gradio-ui/
+class _SentinelTokenStoppingCriteria(StoppingCriteria):
+
+    def __init__(self, sentinel_token_ids: list, starting_idx: int):
+        StoppingCriteria.__init__(self)
+        self.sentinel_token_ids = sentinel_token_ids
+        self.starting_idx = starting_idx
+
+    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
+        for sample in input_ids:
+            trimmed_sample = sample[self.starting_idx:]
+
+            for i in range(len(self.sentinel_token_ids)):
+                # Can't unfold, output is still too tiny. Skip.
+                if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
+                    continue
+                for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
+                    if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
+                        return True
+        return False
+
+
 # Copy from text-generation-webui/modules/callbacks.py
 class Stream(StoppingCriteria):
     def __init__(self, callback_func=None):
diff --git a/text-generation-webui/generate_monkey_patch.py b/text-generation-webui/generate_monkey_patch.py
index efe52fa..d3c59b1 100644
--- a/text-generation-webui/generate_monkey_patch.py
+++ b/text-generation-webui/generate_monkey_patch.py
@@ -1,6 +1,6 @@
 import modules.text_generation
 from modules.text_generation import *
-from modules.callbacks import _SentinelTokenStoppingCriteria
+from alpaca_lora_4bit.server import _SentinelTokenStoppingCriteria
 
 def generate_reply_patched(question, state, eos_token=None, stopping_strings=[]):
     if shared.model_name == 'None' or shared.model is None:
@@ -115,34 +115,8 @@ def generate_reply_patched(question, state, eos_token=None, stopping_strings=[])
     # Stream the reply 1 token at a time.
     # This is based on the trick of using 'stopping_criteria' to create an iterator.
     elif not shared.args.flexgen:
-
-        # def generate_with_callback(callback=None, **kwargs):
-        #     kwargs['stopping_criteria'].append(Stream(callback_func=callback))
-        #     clear_torch_cache()
-        #     with torch.no_grad():
-        #         shared.model.generate(**kwargs)
-
-        # def generate_with_streaming(**kwargs):
-        #     return Iteratorize(generate_with_callback, kwargs, callback=None)
-
-        # if not shared.is_chat():
-        #     yield formatted_outputs(original_question, shared.model_name)
-
-        # with generate_with_streaming(**generate_params) as generator:
-        #     for output in generator:
-        #         if shared.soft_prompt:
-        #             output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-
-        #         new_tokens = len(output) - len(input_ids[0])
-        #         reply = decode(output[-new_tokens:], state['skip_special_tokens'])
-        #         if not shared.is_chat():
-        #             reply = original_question + apply_extensions('output', reply)
-
-        #         if output[-1] in eos_token_ids:
-        #             break
-
-        #         yield formatted_outputs(reply, shared.model_name)
-
+
+        # Replaced the original streaming loop with the socket server callback
         from queue import Queue
         queue = Queue()
         def callback_func(x, is_end=False):
@@ -151,9 +125,6 @@ def generate_reply_patched(question, state, eos_token=None, stopping_strings=[])
             else:
                 queue.put(None)