add server
This commit is contained in:
parent
633c28fd25
commit
1abdc99675
|
|
@ -5,6 +5,7 @@ sentencepiece
|
||||||
safetensors
|
safetensors
|
||||||
einops
|
einops
|
||||||
colorama
|
colorama
|
||||||
|
pyzmq
|
||||||
git+https://github.com/huggingface/peft.git@70af02a2bca5a63921790036b2c9430edf4037e2
|
git+https://github.com/huggingface/peft.git@70af02a2bca5a63921790036b2c9430edf4037e2
|
||||||
git+https://github.com/huggingface/transformers.git
|
git+https://github.com/huggingface/transformers.git
|
||||||
git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
|
git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
from server import ModelServer
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
arg_parser = argparse.ArgumentParser()
|
||||||
|
arg_parser.add_argument('--config_path', type=str, required=True)
|
||||||
|
arg_parser.add_argument('--model_path', type=str, required=True)
|
||||||
|
arg_parser.add_argument('--lora_path', type=str, default=None)
|
||||||
|
arg_parser.add_argument('--groupsize', type=int, default=-1)
|
||||||
|
arg_parser.add_argument('--v1', action='store_true')
|
||||||
|
arg_parser.add_argument('--quant_attn', action='store_true')
|
||||||
|
arg_parser.add_argument('--port', type=int, default=5555)
|
||||||
|
arg_parser.add_argument('--pub_port', type=int, default=5556)
|
||||||
|
args = arg_parser.parse_args()
|
||||||
|
|
||||||
|
server = ModelServer(
|
||||||
|
config_path=args.config_path,
|
||||||
|
model_path=args.model_path,
|
||||||
|
lora_path=args.lora_path,
|
||||||
|
groupsize=args.groupsize,
|
||||||
|
is_v1_model=args.v1,
|
||||||
|
quant_attn=args.quant_attn,
|
||||||
|
port=args.port,
|
||||||
|
pub_port=args.pub_port)
|
||||||
|
|
||||||
|
server.run()
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from .server import ModelClient, ModelServer
|
||||||
|
|
@ -0,0 +1,264 @@
|
||||||
|
from .. import autograd_4bit
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from ..autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
|
||||||
|
from alpaca_lora_4bit.model_attn_mlp_patch import make_quant_attn, make_fused_mlp, inject_lora_layers
|
||||||
|
import zmq
|
||||||
|
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||||
|
from io import BytesIO
|
||||||
|
import gc
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
def decode(output_ids, tokenizer, skip_special_tokens=True):
|
||||||
|
if skip_special_tokens:
|
||||||
|
reply = tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||||
|
reply = reply.replace(r'<|endoftext|>', '')
|
||||||
|
return reply
|
||||||
|
else:
|
||||||
|
return tokenizer.decode(output_ids, skip_special_tokens=False)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_torch_cache():
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
# Copy from text-generation-webui/modules/callbacks.py
|
||||||
|
class Stream(StoppingCriteria):
|
||||||
|
def __init__(self, callback_func=None):
|
||||||
|
self.callback_func = callback_func
|
||||||
|
|
||||||
|
def __call__(self, input_ids, scores) -> bool:
|
||||||
|
if self.callback_func is not None:
|
||||||
|
self.callback_func(input_ids[0])
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class ModelServer:
|
||||||
|
|
||||||
|
def __init__(self, config_path, model_path, lora_path=None, groupsize=128, is_v1_model=False, quant_attn=False, port=5555, pub_port=5556):
|
||||||
|
self.config_path = config_path
|
||||||
|
self.model_path = model_path
|
||||||
|
self.lora_path = lora_path
|
||||||
|
self.groupsize = groupsize
|
||||||
|
self.is_v1_model = is_v1_model
|
||||||
|
self.quant_attn = quant_attn
|
||||||
|
self.port = port
|
||||||
|
self.model = None
|
||||||
|
self.tokenizer = None
|
||||||
|
self.is_generating = False
|
||||||
|
self.socket = None
|
||||||
|
self.socket_pub = None
|
||||||
|
self.pub_port = pub_port
|
||||||
|
self.topic = b'10001'
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
print("Loading {} ...".format(self.model_path))
|
||||||
|
t0 = time.time()
|
||||||
|
model, tokenizer = load_llama_model_4bit_low_ram(self.config_path, self.model_path, groupsize=self.groupsize, is_v1_model=self.is_v1_model)
|
||||||
|
|
||||||
|
if not self.quant_attn and self.lora_path is not None:
|
||||||
|
from peft import PeftModel
|
||||||
|
from ..monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
|
||||||
|
replace_peft_model_with_int4_lora_model()
|
||||||
|
model = PeftModel.from_pretrained(model, self.lora_path, device_map={'': 0}, torch_dtype=torch.float16)
|
||||||
|
print('{} Lora Applied.'.format(self.lora_path))
|
||||||
|
|
||||||
|
print('Apply half ...')
|
||||||
|
model.half()
|
||||||
|
for n, m in model.named_modules():
|
||||||
|
if isinstance(m, Autograd4bitQuantLinear):
|
||||||
|
if m.is_v1_model:
|
||||||
|
m.zeros = m.zeros.half()
|
||||||
|
m.scales = m.scales.half()
|
||||||
|
m.bias = m.bias.half()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
print('Total {:.2f} GiB VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024))
|
||||||
|
|
||||||
|
if not self.quant_attn and self.lora_path is not None:
|
||||||
|
from ..amp_wrapper import AMPWrapper
|
||||||
|
wrapper = AMPWrapper(model)
|
||||||
|
wrapper.apply_generate()
|
||||||
|
print('AMP applied.')
|
||||||
|
|
||||||
|
if self.quant_attn:
|
||||||
|
make_quant_attn(model, is_v1_model=self.is_v1_model)
|
||||||
|
make_fused_mlp(model, is_v1_model=self.is_v1_model)
|
||||||
|
print('Quantized attention applied.')
|
||||||
|
|
||||||
|
if self.lora_path is not None:
|
||||||
|
inject_lora_layers(model, self.lora_path, device='cuda', torch_dtype=torch.float16)
|
||||||
|
|
||||||
|
self.model, self.tokenizer = model, tokenizer
|
||||||
|
print("Loaded in {:.2f} seconds.".format(time.time() - t0))
|
||||||
|
|
||||||
|
def wrap_result(self, result):
|
||||||
|
with BytesIO() as bio:
|
||||||
|
torch.save(result, bio)
|
||||||
|
return bio.getvalue()
|
||||||
|
|
||||||
|
def unwrap_result(self, result):
|
||||||
|
with BytesIO(result) as bio:
|
||||||
|
return torch.load(bio, map_location='cuda')
|
||||||
|
|
||||||
|
def send_generate_end_flag(self):
|
||||||
|
data = {
|
||||||
|
'type': 'generate_end'
|
||||||
|
}
|
||||||
|
self.socket_pub.send(self.topic + self.wrap_result(data))
|
||||||
|
|
||||||
|
def generate_thread(self, *args, **kwargs):
|
||||||
|
clear_torch_cache()
|
||||||
|
self.is_generating = True
|
||||||
|
try:
|
||||||
|
self.model.generate(*args, **kwargs)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self.is_generating = False
|
||||||
|
self.send_generate_end_flag()
|
||||||
|
clear_torch_cache()
|
||||||
|
|
||||||
|
def stop_generate(self):
|
||||||
|
self.is_generating = False
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.load_model()
|
||||||
|
context = zmq.Context()
|
||||||
|
socket = context.socket(zmq.REP)
|
||||||
|
socket.bind("tcp://*:{}".format(self.port))
|
||||||
|
self.socket = socket
|
||||||
|
context_pub = zmq.Context()
|
||||||
|
socket_pub = context_pub.socket(zmq.PUB)
|
||||||
|
socket_pub.bind("tcp://*:{}".format(self.pub_port))
|
||||||
|
self.socket_pub = socket_pub
|
||||||
|
print('Server started at port {} and {}.'.format(self.port, self.pub_port))
|
||||||
|
'''
|
||||||
|
Message Format:
|
||||||
|
{'function': 'generate',
|
||||||
|
'args': ...,
|
||||||
|
'kwargs': ...}
|
||||||
|
'''
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Wait for next request from client
|
||||||
|
message = socket.recv()
|
||||||
|
message = self.unwrap_result(message)
|
||||||
|
function = message['function']
|
||||||
|
if function == 'generate':
|
||||||
|
if not self.is_generating:
|
||||||
|
self.is_generating = True
|
||||||
|
args = message['args']
|
||||||
|
kwargs = message['kwargs']
|
||||||
|
input_ids = kwargs['inputs']
|
||||||
|
def func(x):
|
||||||
|
if not self.is_generating:
|
||||||
|
raise ValueError
|
||||||
|
new_tokens = len(x) - len(input_ids[0])
|
||||||
|
result = decode(x[-new_tokens:], self.tokenizer, True)
|
||||||
|
data = {
|
||||||
|
'type': 'generate',
|
||||||
|
'data': result
|
||||||
|
}
|
||||||
|
socket_pub.send(self.topic + self.wrap_result(data))
|
||||||
|
kwargs['stopping_criteria'] = StoppingCriteriaList([Stream(callback_func=func)])
|
||||||
|
t = threading.Thread(target=self.generate_thread, args=args, kwargs=kwargs)
|
||||||
|
t.setDaemon(True)
|
||||||
|
t.start()
|
||||||
|
socket.send(self.wrap_result({'type': 'generate_rsp', 'data': 'ok'}))
|
||||||
|
else:
|
||||||
|
print('Already generating.')
|
||||||
|
socket.send(self.wrap_result({'type': 'generate_rsp', 'data': 'already generating'}))
|
||||||
|
elif function == 'stop_generate':
|
||||||
|
self.stop_generate()
|
||||||
|
socket.send(self.wrap_result({'type': 'stop_generate_rsp', 'data': 'ok'}))
|
||||||
|
elif function == 'test':
|
||||||
|
print('test ok.')
|
||||||
|
self.socket.send(self.wrap_result(
|
||||||
|
{
|
||||||
|
'type': 'test',
|
||||||
|
'data': 'test ok.'
|
||||||
|
}
|
||||||
|
))
|
||||||
|
elif function == 'exit':
|
||||||
|
socket.send(self.wrap_result({'type': 'exit_rsp', 'data': 'ok'}))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
socket.send(self.wrap_result({'type': 'rsp', 'data': 'no function'}))
|
||||||
|
raise ValueError('Unknown function {}'.format(function))
|
||||||
|
except Exception as e:
|
||||||
|
print(str(e))
|
||||||
|
raise
|
||||||
|
print('Server stopped.')
|
||||||
|
|
||||||
|
|
||||||
|
class ModelClient:
|
||||||
|
|
||||||
|
def __init__(self, port=5555, port_sub=5556):
|
||||||
|
self.port = port
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.socket = self.context.socket(zmq.REQ)
|
||||||
|
self.socket.connect("tcp://localhost:{}".format(self.port))
|
||||||
|
self.socket_sub = self.context.socket(zmq.SUB)
|
||||||
|
self.topic = b'10001'
|
||||||
|
self.socket_sub.setsockopt(zmq.SUBSCRIBE, self.topic)
|
||||||
|
self.socket_sub.connect("tcp://localhost:{}".format(port_sub))
|
||||||
|
self.callback_func = None
|
||||||
|
|
||||||
|
def wrap_result(self, result):
|
||||||
|
with BytesIO() as bio:
|
||||||
|
torch.save(result, bio)
|
||||||
|
return bio.getvalue()
|
||||||
|
|
||||||
|
def unwrap_result(self, result):
|
||||||
|
with BytesIO(result) as bio:
|
||||||
|
return torch.load(bio, map_location='cuda')
|
||||||
|
|
||||||
|
def recieve_thread(self):
|
||||||
|
while True:
|
||||||
|
message = self.socket_sub.recv()
|
||||||
|
message = message[len(self.topic):]
|
||||||
|
message = self.unwrap_result(message)
|
||||||
|
if message['type'] == 'generate':
|
||||||
|
if self.callback_func is not None:
|
||||||
|
self.callback_func(message['data'], is_end=False)
|
||||||
|
elif message['type'] == 'generate_end':
|
||||||
|
if self.callback_func is not None:
|
||||||
|
self.callback_func(None, is_end=True)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(message)
|
||||||
|
break
|
||||||
|
print('receive completed.')
|
||||||
|
|
||||||
|
def start_recieving(self):
|
||||||
|
t = threading.Thread(target=self.recieve_thread)
|
||||||
|
t.setDaemon(True)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
def generate(self, *args, **kwargs):
|
||||||
|
data = {
|
||||||
|
'function': 'generate',
|
||||||
|
'args': args,
|
||||||
|
'kwargs': kwargs
|
||||||
|
}
|
||||||
|
self.socket.send(self.wrap_result(data))
|
||||||
|
result = self.socket.recv()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
data = {
|
||||||
|
'function': 'stop_generate'
|
||||||
|
}
|
||||||
|
self.socket.send(self.wrap_result(data))
|
||||||
|
result = self.socket.recv()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def test(self):
|
||||||
|
data = {
|
||||||
|
'function': 'test'
|
||||||
|
}
|
||||||
|
self.socket.send(self.wrap_result(data))
|
||||||
|
result = self.socket.recv()
|
||||||
|
return result
|
||||||
|
|
@ -0,0 +1,37 @@
|
||||||
|
from server import ModelClient
|
||||||
|
from transformers import LlamaTokenizer
|
||||||
|
|
||||||
|
def load_model_llama(*args, **kwargs):
|
||||||
|
config_path = '../llama-13b-4bit/'
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(config_path)
|
||||||
|
tokenizer.truncation_side = 'left'
|
||||||
|
model = ModelClient(port=5555, port_sub=5556)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
patch_encode_func = True
|
||||||
|
|
||||||
|
# Monkey Patch
|
||||||
|
from modules import models
|
||||||
|
from modules import shared
|
||||||
|
models.load_model = load_model_llama
|
||||||
|
shared.args.model = 'llama-13b-4bit'
|
||||||
|
shared.settings['name1'] = 'You'
|
||||||
|
shared.settings['name2'] = 'Assistant'
|
||||||
|
shared.settings['chat_prompt_size_max'] = 2048
|
||||||
|
shared.settings['chat_prompt_size'] = 2048
|
||||||
|
|
||||||
|
if patch_encode_func:
|
||||||
|
from modules import text_generation
|
||||||
|
text_generation.encode_old = text_generation.encode
|
||||||
|
def encode_patched(*args, **kwargs):
|
||||||
|
input_ids = text_generation.encode_old(*args, **kwargs)
|
||||||
|
if input_ids[0,0] == 0:
|
||||||
|
input_ids = input_ids[:, 1:]
|
||||||
|
return input_ids
|
||||||
|
text_generation.encode = encode_patched
|
||||||
|
print('Encode Function Patched.')
|
||||||
|
|
||||||
|
print('Monkey Patch Completed.')
|
||||||
|
|
||||||
|
# Apply Generate Monkey Patch
|
||||||
|
import generate_monkey_patch
|
||||||
|
|
@ -0,0 +1,213 @@
|
||||||
|
import modules.text_generation
|
||||||
|
from modules.text_generation import *
|
||||||
|
from modules.callbacks import _SentinelTokenStoppingCriteria
|
||||||
|
|
||||||
|
def generate_reply_patched(question, state, eos_token=None, stopping_strings=[]):
|
||||||
|
if shared.model_name == 'None' or shared.model is None:
|
||||||
|
print("No model is loaded! Select one in the Model tab.")
|
||||||
|
yield formatted_outputs(question, shared.model_name)
|
||||||
|
return
|
||||||
|
|
||||||
|
clear_torch_cache()
|
||||||
|
seed = set_manual_seed(state['seed'])
|
||||||
|
shared.stop_everything = False
|
||||||
|
generate_params = get_generate_params(state)
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
# Preparing the input
|
||||||
|
original_question = question
|
||||||
|
if not shared.is_chat():
|
||||||
|
question = apply_extensions('input', question)
|
||||||
|
|
||||||
|
# If the model is not on transformers, handle it separately and end this
|
||||||
|
# function call earlier.
|
||||||
|
if shared.model_type in ['rwkv', 'llamacpp']:
|
||||||
|
if shared.args.verbose:
|
||||||
|
print(f'\n\n{question}\n--------------------\n')
|
||||||
|
|
||||||
|
try:
|
||||||
|
if shared.args.no_stream:
|
||||||
|
reply = shared.model.generate(context=question, **generate_params)
|
||||||
|
output = original_question + reply
|
||||||
|
if not shared.is_chat():
|
||||||
|
reply = original_question + apply_extensions('output', reply)
|
||||||
|
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
else:
|
||||||
|
if not shared.is_chat():
|
||||||
|
yield formatted_outputs(question, shared.model_name)
|
||||||
|
|
||||||
|
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||||
|
output = original_question + reply
|
||||||
|
if not shared.is_chat():
|
||||||
|
reply = original_question + apply_extensions('output', reply)
|
||||||
|
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
t1 = time.time()
|
||||||
|
original_tokens = len(encode(original_question)[0])
|
||||||
|
new_tokens = len(encode(output)[0]) - original_tokens
|
||||||
|
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||||
|
return
|
||||||
|
|
||||||
|
# Encode the input
|
||||||
|
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||||
|
output = input_ids[0]
|
||||||
|
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
||||||
|
if shared.args.verbose:
|
||||||
|
print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n')
|
||||||
|
|
||||||
|
# Find the eos tokens
|
||||||
|
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||||
|
if eos_token is not None:
|
||||||
|
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||||
|
|
||||||
|
# Create the StoppingCriteriaList with the stopping strings
|
||||||
|
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||||
|
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
||||||
|
if type(st) is list and len(st) > 0:
|
||||||
|
sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
|
||||||
|
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
|
||||||
|
break
|
||||||
|
|
||||||
|
# Update generate_params with the eos token and the stopping strings
|
||||||
|
if shared.args.flexgen:
|
||||||
|
generate_params['stop'] = eos_token_ids[-1]
|
||||||
|
else:
|
||||||
|
generate_params['eos_token_id'] = eos_token_ids
|
||||||
|
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||||
|
|
||||||
|
# Add the encoded tokens to generate_params
|
||||||
|
if shared.soft_prompt:
|
||||||
|
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||||
|
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
|
||||||
|
original_input_ids = input_ids
|
||||||
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
|
generate_params.update({'inputs': filler_input_ids})
|
||||||
|
else:
|
||||||
|
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||||
|
original_input_ids = input_ids
|
||||||
|
generate_params.update({'inputs': input_ids})
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate the entire reply at once.
|
||||||
|
if shared.args.no_stream:
|
||||||
|
with torch.no_grad():
|
||||||
|
output = shared.model.generate(**generate_params)[0]
|
||||||
|
if cuda:
|
||||||
|
output = output.cuda()
|
||||||
|
|
||||||
|
if shared.soft_prompt:
|
||||||
|
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||||
|
|
||||||
|
new_tokens = len(output) - len(input_ids[0])
|
||||||
|
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||||
|
if not shared.is_chat():
|
||||||
|
reply = original_question + apply_extensions('output', reply)
|
||||||
|
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
# Stream the reply 1 token at a time.
|
||||||
|
# This is based on the trick of using 'stopping_criteria' to create an iterator.
|
||||||
|
elif not shared.args.flexgen:
|
||||||
|
|
||||||
|
# def generate_with_callback(callback=None, **kwargs):
|
||||||
|
# kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
||||||
|
# clear_torch_cache()
|
||||||
|
# with torch.no_grad():
|
||||||
|
# shared.model.generate(**kwargs)
|
||||||
|
|
||||||
|
# def generate_with_streaming(**kwargs):
|
||||||
|
# return Iteratorize(generate_with_callback, kwargs, callback=None)
|
||||||
|
|
||||||
|
# if not shared.is_chat():
|
||||||
|
# yield formatted_outputs(original_question, shared.model_name)
|
||||||
|
|
||||||
|
# with generate_with_streaming(**generate_params) as generator:
|
||||||
|
# for output in generator:
|
||||||
|
# if shared.soft_prompt:
|
||||||
|
# output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||||
|
|
||||||
|
# new_tokens = len(output) - len(input_ids[0])
|
||||||
|
# reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||||
|
# if not shared.is_chat():
|
||||||
|
# reply = original_question + apply_extensions('output', reply)
|
||||||
|
|
||||||
|
# if output[-1] in eos_token_ids:
|
||||||
|
# break
|
||||||
|
|
||||||
|
# yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
from queue import Queue
|
||||||
|
queue = Queue()
|
||||||
|
def callback_func(x, is_end=False):
|
||||||
|
if not is_end:
|
||||||
|
queue.put(x)
|
||||||
|
else:
|
||||||
|
queue.put(None)
|
||||||
|
|
||||||
|
# remove stopping_criteria
|
||||||
|
generate_params.pop('stopping_criteria')
|
||||||
|
|
||||||
|
shared.model.callback_func = callback_func
|
||||||
|
shared.model.generate(**generate_params)
|
||||||
|
shared.model.start_recieving()
|
||||||
|
|
||||||
|
token_count = 0
|
||||||
|
while True:
|
||||||
|
reply = queue.get()
|
||||||
|
if reply is None:
|
||||||
|
break
|
||||||
|
token_count += 1
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||||
|
else:
|
||||||
|
for i in range(state['max_new_tokens'] // 8 + 1):
|
||||||
|
clear_torch_cache()
|
||||||
|
with torch.no_grad():
|
||||||
|
output = shared.model.generate(**generate_params)[0]
|
||||||
|
|
||||||
|
if shared.soft_prompt:
|
||||||
|
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||||
|
|
||||||
|
new_tokens = len(output) - len(original_input_ids[0])
|
||||||
|
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||||
|
if not shared.is_chat():
|
||||||
|
reply = original_question + apply_extensions('output', reply)
|
||||||
|
|
||||||
|
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||||
|
break
|
||||||
|
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||||
|
if shared.soft_prompt:
|
||||||
|
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||||
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
|
generate_params.update({'inputs': filler_input_ids})
|
||||||
|
else:
|
||||||
|
generate_params.update({'inputs': input_ids})
|
||||||
|
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
t1 = time.time()
|
||||||
|
try:
|
||||||
|
shared.model.stop()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
original_tokens = len(original_input_ids[0])
|
||||||
|
new_tokens = token_count
|
||||||
|
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||||
|
return
|
||||||
|
|
||||||
|
modules.text_generation.generate_reply_old = modules.text_generation.generate_reply
|
||||||
|
modules.text_generation.generate_reply = generate_reply_patched
|
||||||
|
print('Generate Patch Applied')
|
||||||
Loading…
Reference in New Issue