add server
This commit is contained in:
parent
633c28fd25
commit
1abdc99675
|
|
@ -5,6 +5,7 @@ sentencepiece
|
|||
safetensors
|
||||
einops
|
||||
colorama
|
||||
pyzmq
|
||||
git+https://github.com/huggingface/peft.git@70af02a2bca5a63921790036b2c9430edf4037e2
|
||||
git+https://github.com/huggingface/transformers.git
|
||||
git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
from server import ModelServer
|
||||
import argparse
|
||||
|
||||
if __name__ == '__main__':
|
||||
arg_parser = argparse.ArgumentParser()
|
||||
arg_parser.add_argument('--config_path', type=str, required=True)
|
||||
arg_parser.add_argument('--model_path', type=str, required=True)
|
||||
arg_parser.add_argument('--lora_path', type=str, default=None)
|
||||
arg_parser.add_argument('--groupsize', type=int, default=-1)
|
||||
arg_parser.add_argument('--v1', action='store_true')
|
||||
arg_parser.add_argument('--quant_attn', action='store_true')
|
||||
arg_parser.add_argument('--port', type=int, default=5555)
|
||||
arg_parser.add_argument('--pub_port', type=int, default=5556)
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
server = ModelServer(
|
||||
config_path=args.config_path,
|
||||
model_path=args.model_path,
|
||||
lora_path=args.lora_path,
|
||||
groupsize=args.groupsize,
|
||||
is_v1_model=args.v1,
|
||||
quant_attn=args.quant_attn,
|
||||
port=args.port,
|
||||
pub_port=args.pub_port)
|
||||
|
||||
server.run()
|
||||
|
|
@ -0,0 +1 @@
|
|||
from .server import ModelClient, ModelServer
|
||||
|
|
@ -0,0 +1,264 @@
|
|||
from .. import autograd_4bit
|
||||
import time
|
||||
import torch
|
||||
from ..autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
|
||||
from alpaca_lora_4bit.model_attn_mlp_patch import make_quant_attn, make_fused_mlp, inject_lora_layers
|
||||
import zmq
|
||||
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||
from io import BytesIO
|
||||
import gc
|
||||
import threading
|
||||
|
||||
|
||||
def decode(output_ids, tokenizer, skip_special_tokens=True):
|
||||
if skip_special_tokens:
|
||||
reply = tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||
reply = reply.replace(r'<|endoftext|>', '')
|
||||
return reply
|
||||
else:
|
||||
return tokenizer.decode(output_ids, skip_special_tokens=False)
|
||||
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# Copy from text-generation-webui/modules/callbacks.py
|
||||
class Stream(StoppingCriteria):
|
||||
def __init__(self, callback_func=None):
|
||||
self.callback_func = callback_func
|
||||
|
||||
def __call__(self, input_ids, scores) -> bool:
|
||||
if self.callback_func is not None:
|
||||
self.callback_func(input_ids[0])
|
||||
return False
|
||||
|
||||
|
||||
class ModelServer:
|
||||
|
||||
def __init__(self, config_path, model_path, lora_path=None, groupsize=128, is_v1_model=False, quant_attn=False, port=5555, pub_port=5556):
|
||||
self.config_path = config_path
|
||||
self.model_path = model_path
|
||||
self.lora_path = lora_path
|
||||
self.groupsize = groupsize
|
||||
self.is_v1_model = is_v1_model
|
||||
self.quant_attn = quant_attn
|
||||
self.port = port
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.is_generating = False
|
||||
self.socket = None
|
||||
self.socket_pub = None
|
||||
self.pub_port = pub_port
|
||||
self.topic = b'10001'
|
||||
|
||||
def load_model(self):
|
||||
print("Loading {} ...".format(self.model_path))
|
||||
t0 = time.time()
|
||||
model, tokenizer = load_llama_model_4bit_low_ram(self.config_path, self.model_path, groupsize=self.groupsize, is_v1_model=self.is_v1_model)
|
||||
|
||||
if not self.quant_attn and self.lora_path is not None:
|
||||
from peft import PeftModel
|
||||
from ..monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
|
||||
replace_peft_model_with_int4_lora_model()
|
||||
model = PeftModel.from_pretrained(model, self.lora_path, device_map={'': 0}, torch_dtype=torch.float16)
|
||||
print('{} Lora Applied.'.format(self.lora_path))
|
||||
|
||||
print('Apply half ...')
|
||||
model.half()
|
||||
for n, m in model.named_modules():
|
||||
if isinstance(m, Autograd4bitQuantLinear):
|
||||
if m.is_v1_model:
|
||||
m.zeros = m.zeros.half()
|
||||
m.scales = m.scales.half()
|
||||
m.bias = m.bias.half()
|
||||
torch.cuda.empty_cache()
|
||||
print('Total {:.2f} GiB VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024))
|
||||
|
||||
if not self.quant_attn and self.lora_path is not None:
|
||||
from ..amp_wrapper import AMPWrapper
|
||||
wrapper = AMPWrapper(model)
|
||||
wrapper.apply_generate()
|
||||
print('AMP applied.')
|
||||
|
||||
if self.quant_attn:
|
||||
make_quant_attn(model, is_v1_model=self.is_v1_model)
|
||||
make_fused_mlp(model, is_v1_model=self.is_v1_model)
|
||||
print('Quantized attention applied.')
|
||||
|
||||
if self.lora_path is not None:
|
||||
inject_lora_layers(model, self.lora_path, device='cuda', torch_dtype=torch.float16)
|
||||
|
||||
self.model, self.tokenizer = model, tokenizer
|
||||
print("Loaded in {:.2f} seconds.".format(time.time() - t0))
|
||||
|
||||
def wrap_result(self, result):
|
||||
with BytesIO() as bio:
|
||||
torch.save(result, bio)
|
||||
return bio.getvalue()
|
||||
|
||||
def unwrap_result(self, result):
|
||||
with BytesIO(result) as bio:
|
||||
return torch.load(bio, map_location='cuda')
|
||||
|
||||
def send_generate_end_flag(self):
|
||||
data = {
|
||||
'type': 'generate_end'
|
||||
}
|
||||
self.socket_pub.send(self.topic + self.wrap_result(data))
|
||||
|
||||
def generate_thread(self, *args, **kwargs):
|
||||
clear_torch_cache()
|
||||
self.is_generating = True
|
||||
try:
|
||||
self.model.generate(*args, **kwargs)
|
||||
except ValueError:
|
||||
pass
|
||||
finally:
|
||||
self.is_generating = False
|
||||
self.send_generate_end_flag()
|
||||
clear_torch_cache()
|
||||
|
||||
def stop_generate(self):
|
||||
self.is_generating = False
|
||||
|
||||
def run(self):
|
||||
self.load_model()
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REP)
|
||||
socket.bind("tcp://*:{}".format(self.port))
|
||||
self.socket = socket
|
||||
context_pub = zmq.Context()
|
||||
socket_pub = context_pub.socket(zmq.PUB)
|
||||
socket_pub.bind("tcp://*:{}".format(self.pub_port))
|
||||
self.socket_pub = socket_pub
|
||||
print('Server started at port {} and {}.'.format(self.port, self.pub_port))
|
||||
'''
|
||||
Message Format:
|
||||
{'function': 'generate',
|
||||
'args': ...,
|
||||
'kwargs': ...}
|
||||
'''
|
||||
while True:
|
||||
try:
|
||||
# Wait for next request from client
|
||||
message = socket.recv()
|
||||
message = self.unwrap_result(message)
|
||||
function = message['function']
|
||||
if function == 'generate':
|
||||
if not self.is_generating:
|
||||
self.is_generating = True
|
||||
args = message['args']
|
||||
kwargs = message['kwargs']
|
||||
input_ids = kwargs['inputs']
|
||||
def func(x):
|
||||
if not self.is_generating:
|
||||
raise ValueError
|
||||
new_tokens = len(x) - len(input_ids[0])
|
||||
result = decode(x[-new_tokens:], self.tokenizer, True)
|
||||
data = {
|
||||
'type': 'generate',
|
||||
'data': result
|
||||
}
|
||||
socket_pub.send(self.topic + self.wrap_result(data))
|
||||
kwargs['stopping_criteria'] = StoppingCriteriaList([Stream(callback_func=func)])
|
||||
t = threading.Thread(target=self.generate_thread, args=args, kwargs=kwargs)
|
||||
t.setDaemon(True)
|
||||
t.start()
|
||||
socket.send(self.wrap_result({'type': 'generate_rsp', 'data': 'ok'}))
|
||||
else:
|
||||
print('Already generating.')
|
||||
socket.send(self.wrap_result({'type': 'generate_rsp', 'data': 'already generating'}))
|
||||
elif function == 'stop_generate':
|
||||
self.stop_generate()
|
||||
socket.send(self.wrap_result({'type': 'stop_generate_rsp', 'data': 'ok'}))
|
||||
elif function == 'test':
|
||||
print('test ok.')
|
||||
self.socket.send(self.wrap_result(
|
||||
{
|
||||
'type': 'test',
|
||||
'data': 'test ok.'
|
||||
}
|
||||
))
|
||||
elif function == 'exit':
|
||||
socket.send(self.wrap_result({'type': 'exit_rsp', 'data': 'ok'}))
|
||||
break
|
||||
else:
|
||||
socket.send(self.wrap_result({'type': 'rsp', 'data': 'no function'}))
|
||||
raise ValueError('Unknown function {}'.format(function))
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
raise
|
||||
print('Server stopped.')
|
||||
|
||||
|
||||
class ModelClient:
|
||||
|
||||
def __init__(self, port=5555, port_sub=5556):
|
||||
self.port = port
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.REQ)
|
||||
self.socket.connect("tcp://localhost:{}".format(self.port))
|
||||
self.socket_sub = self.context.socket(zmq.SUB)
|
||||
self.topic = b'10001'
|
||||
self.socket_sub.setsockopt(zmq.SUBSCRIBE, self.topic)
|
||||
self.socket_sub.connect("tcp://localhost:{}".format(port_sub))
|
||||
self.callback_func = None
|
||||
|
||||
def wrap_result(self, result):
|
||||
with BytesIO() as bio:
|
||||
torch.save(result, bio)
|
||||
return bio.getvalue()
|
||||
|
||||
def unwrap_result(self, result):
|
||||
with BytesIO(result) as bio:
|
||||
return torch.load(bio, map_location='cuda')
|
||||
|
||||
def recieve_thread(self):
|
||||
while True:
|
||||
message = self.socket_sub.recv()
|
||||
message = message[len(self.topic):]
|
||||
message = self.unwrap_result(message)
|
||||
if message['type'] == 'generate':
|
||||
if self.callback_func is not None:
|
||||
self.callback_func(message['data'], is_end=False)
|
||||
elif message['type'] == 'generate_end':
|
||||
if self.callback_func is not None:
|
||||
self.callback_func(None, is_end=True)
|
||||
break
|
||||
else:
|
||||
print(message)
|
||||
break
|
||||
print('receive completed.')
|
||||
|
||||
def start_recieving(self):
|
||||
t = threading.Thread(target=self.recieve_thread)
|
||||
t.setDaemon(True)
|
||||
t.start()
|
||||
|
||||
def generate(self, *args, **kwargs):
|
||||
data = {
|
||||
'function': 'generate',
|
||||
'args': args,
|
||||
'kwargs': kwargs
|
||||
}
|
||||
self.socket.send(self.wrap_result(data))
|
||||
result = self.socket.recv()
|
||||
return result
|
||||
|
||||
def stop(self):
|
||||
data = {
|
||||
'function': 'stop_generate'
|
||||
}
|
||||
self.socket.send(self.wrap_result(data))
|
||||
result = self.socket.recv()
|
||||
return result
|
||||
|
||||
def test(self):
|
||||
data = {
|
||||
'function': 'test'
|
||||
}
|
||||
self.socket.send(self.wrap_result(data))
|
||||
result = self.socket.recv()
|
||||
return result
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
from server import ModelClient
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
def load_model_llama(*args, **kwargs):
|
||||
config_path = '../llama-13b-4bit/'
|
||||
tokenizer = LlamaTokenizer.from_pretrained(config_path)
|
||||
tokenizer.truncation_side = 'left'
|
||||
model = ModelClient(port=5555, port_sub=5556)
|
||||
return model, tokenizer
|
||||
|
||||
patch_encode_func = True
|
||||
|
||||
# Monkey Patch
|
||||
from modules import models
|
||||
from modules import shared
|
||||
models.load_model = load_model_llama
|
||||
shared.args.model = 'llama-13b-4bit'
|
||||
shared.settings['name1'] = 'You'
|
||||
shared.settings['name2'] = 'Assistant'
|
||||
shared.settings['chat_prompt_size_max'] = 2048
|
||||
shared.settings['chat_prompt_size'] = 2048
|
||||
|
||||
if patch_encode_func:
|
||||
from modules import text_generation
|
||||
text_generation.encode_old = text_generation.encode
|
||||
def encode_patched(*args, **kwargs):
|
||||
input_ids = text_generation.encode_old(*args, **kwargs)
|
||||
if input_ids[0,0] == 0:
|
||||
input_ids = input_ids[:, 1:]
|
||||
return input_ids
|
||||
text_generation.encode = encode_patched
|
||||
print('Encode Function Patched.')
|
||||
|
||||
print('Monkey Patch Completed.')
|
||||
|
||||
# Apply Generate Monkey Patch
|
||||
import generate_monkey_patch
|
||||
|
|
@ -0,0 +1,213 @@
|
|||
import modules.text_generation
|
||||
from modules.text_generation import *
|
||||
from modules.callbacks import _SentinelTokenStoppingCriteria
|
||||
|
||||
def generate_reply_patched(question, state, eos_token=None, stopping_strings=[]):
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
print("No model is loaded! Select one in the Model tab.")
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
return
|
||||
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(state['seed'])
|
||||
shared.stop_everything = False
|
||||
generate_params = get_generate_params(state)
|
||||
t0 = time.time()
|
||||
|
||||
# Preparing the input
|
||||
original_question = question
|
||||
if not shared.is_chat():
|
||||
question = apply_extensions('input', question)
|
||||
|
||||
# If the model is not on transformers, handle it separately and end this
|
||||
# function call earlier.
|
||||
if shared.model_type in ['rwkv', 'llamacpp']:
|
||||
if shared.args.verbose:
|
||||
print(f'\n\n{question}\n--------------------\n')
|
||||
|
||||
try:
|
||||
if shared.args.no_stream:
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
else:
|
||||
if not shared.is_chat():
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
|
||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
t1 = time.time()
|
||||
original_tokens = len(encode(original_question)[0])
|
||||
new_tokens = len(encode(output)[0]) - original_tokens
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
# Encode the input
|
||||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||
output = input_ids[0]
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
||||
if shared.args.verbose:
|
||||
print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n')
|
||||
|
||||
# Find the eos tokens
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
if eos_token is not None:
|
||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||
|
||||
# Create the StoppingCriteriaList with the stopping strings
|
||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
||||
if type(st) is list and len(st) > 0:
|
||||
sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
|
||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
|
||||
break
|
||||
|
||||
# Update generate_params with the eos token and the stopping strings
|
||||
if shared.args.flexgen:
|
||||
generate_params['stop'] = eos_token_ids[-1]
|
||||
else:
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||
|
||||
# Add the encoded tokens to generate_params
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
generate_params.update({'inputs': filler_input_ids})
|
||||
else:
|
||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs': input_ids})
|
||||
if inputs_embeds is not None:
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
|
||||
try:
|
||||
# Generate the entire reply at once.
|
||||
if shared.args.no_stream:
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
if cuda:
|
||||
output = output.cuda()
|
||||
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
# Stream the reply 1 token at a time.
|
||||
# This is based on the trick of using 'stopping_criteria' to create an iterator.
|
||||
elif not shared.args.flexgen:
|
||||
|
||||
# def generate_with_callback(callback=None, **kwargs):
|
||||
# kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
||||
# clear_torch_cache()
|
||||
# with torch.no_grad():
|
||||
# shared.model.generate(**kwargs)
|
||||
|
||||
# def generate_with_streaming(**kwargs):
|
||||
# return Iteratorize(generate_with_callback, kwargs, callback=None)
|
||||
|
||||
# if not shared.is_chat():
|
||||
# yield formatted_outputs(original_question, shared.model_name)
|
||||
|
||||
# with generate_with_streaming(**generate_params) as generator:
|
||||
# for output in generator:
|
||||
# if shared.soft_prompt:
|
||||
# output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
# new_tokens = len(output) - len(input_ids[0])
|
||||
# reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
# if not shared.is_chat():
|
||||
# reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
# if output[-1] in eos_token_ids:
|
||||
# break
|
||||
|
||||
# yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
from queue import Queue
|
||||
queue = Queue()
|
||||
def callback_func(x, is_end=False):
|
||||
if not is_end:
|
||||
queue.put(x)
|
||||
else:
|
||||
queue.put(None)
|
||||
|
||||
# remove stopping_criteria
|
||||
generate_params.pop('stopping_criteria')
|
||||
|
||||
shared.model.callback_func = callback_func
|
||||
shared.model.generate(**generate_params)
|
||||
shared.model.start_recieving()
|
||||
|
||||
token_count = 0
|
||||
while True:
|
||||
reply = queue.get()
|
||||
if reply is None:
|
||||
break
|
||||
token_count += 1
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
else:
|
||||
for i in range(state['max_new_tokens'] // 8 + 1):
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
new_tokens = len(output) - len(original_input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
generate_params.update({'inputs': filler_input_ids})
|
||||
else:
|
||||
generate_params.update({'inputs': input_ids})
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
t1 = time.time()
|
||||
try:
|
||||
shared.model.stop()
|
||||
except:
|
||||
pass
|
||||
original_tokens = len(original_input_ids[0])
|
||||
new_tokens = token_count
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
modules.text_generation.generate_reply_old = modules.text_generation.generate_reply
|
||||
modules.text_generation.generate_reply = generate_reply_patched
|
||||
print('Generate Patch Applied')
|
||||
Loading…
Reference in New Issue