add server

This commit is contained in:
John Smith 2023-04-26 12:50:36 +08:00
parent 633c28fd25
commit 1abdc99675
6 changed files with 542 additions and 0 deletions

View File

@ -5,6 +5,7 @@ sentencepiece
safetensors safetensors
einops einops
colorama colorama
pyzmq
git+https://github.com/huggingface/peft.git@70af02a2bca5a63921790036b2c9430edf4037e2 git+https://github.com/huggingface/peft.git@70af02a2bca5a63921790036b2c9430edf4037e2
git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/transformers.git
git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit

26
scripts/run_server.py Normal file
View File

@ -0,0 +1,26 @@
from server import ModelServer
import argparse
if __name__ == '__main__':
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--config_path', type=str, required=True)
arg_parser.add_argument('--model_path', type=str, required=True)
arg_parser.add_argument('--lora_path', type=str, default=None)
arg_parser.add_argument('--groupsize', type=int, default=-1)
arg_parser.add_argument('--v1', action='store_true')
arg_parser.add_argument('--quant_attn', action='store_true')
arg_parser.add_argument('--port', type=int, default=5555)
arg_parser.add_argument('--pub_port', type=int, default=5556)
args = arg_parser.parse_args()
server = ModelServer(
config_path=args.config_path,
model_path=args.model_path,
lora_path=args.lora_path,
groupsize=args.groupsize,
is_v1_model=args.v1,
quant_attn=args.quant_attn,
port=args.port,
pub_port=args.pub_port)
server.run()

1
server/__init__.py Normal file
View File

@ -0,0 +1 @@
from .server import ModelClient, ModelServer

264
server/server.py Normal file
View File

@ -0,0 +1,264 @@
from .. import autograd_4bit
import time
import torch
from ..autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
from alpaca_lora_4bit.model_attn_mlp_patch import make_quant_attn, make_fused_mlp, inject_lora_layers
import zmq
from transformers import StoppingCriteria, StoppingCriteriaList
from io import BytesIO
import gc
import threading
def decode(output_ids, tokenizer, skip_special_tokens=True):
if skip_special_tokens:
reply = tokenizer.decode(output_ids, skip_special_tokens=True)
reply = reply.replace(r'<|endoftext|>', '')
return reply
else:
return tokenizer.decode(output_ids, skip_special_tokens=False)
def clear_torch_cache():
gc.collect()
torch.cuda.empty_cache()
# Copy from text-generation-webui/modules/callbacks.py
class Stream(StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False
class ModelServer:
def __init__(self, config_path, model_path, lora_path=None, groupsize=128, is_v1_model=False, quant_attn=False, port=5555, pub_port=5556):
self.config_path = config_path
self.model_path = model_path
self.lora_path = lora_path
self.groupsize = groupsize
self.is_v1_model = is_v1_model
self.quant_attn = quant_attn
self.port = port
self.model = None
self.tokenizer = None
self.is_generating = False
self.socket = None
self.socket_pub = None
self.pub_port = pub_port
self.topic = b'10001'
def load_model(self):
print("Loading {} ...".format(self.model_path))
t0 = time.time()
model, tokenizer = load_llama_model_4bit_low_ram(self.config_path, self.model_path, groupsize=self.groupsize, is_v1_model=self.is_v1_model)
if not self.quant_attn and self.lora_path is not None:
from peft import PeftModel
from ..monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
replace_peft_model_with_int4_lora_model()
model = PeftModel.from_pretrained(model, self.lora_path, device_map={'': 0}, torch_dtype=torch.float16)
print('{} Lora Applied.'.format(self.lora_path))
print('Apply half ...')
model.half()
for n, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear):
if m.is_v1_model:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
m.bias = m.bias.half()
torch.cuda.empty_cache()
print('Total {:.2f} GiB VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024))
if not self.quant_attn and self.lora_path is not None:
from ..amp_wrapper import AMPWrapper
wrapper = AMPWrapper(model)
wrapper.apply_generate()
print('AMP applied.')
if self.quant_attn:
make_quant_attn(model, is_v1_model=self.is_v1_model)
make_fused_mlp(model, is_v1_model=self.is_v1_model)
print('Quantized attention applied.')
if self.lora_path is not None:
inject_lora_layers(model, self.lora_path, device='cuda', torch_dtype=torch.float16)
self.model, self.tokenizer = model, tokenizer
print("Loaded in {:.2f} seconds.".format(time.time() - t0))
def wrap_result(self, result):
with BytesIO() as bio:
torch.save(result, bio)
return bio.getvalue()
def unwrap_result(self, result):
with BytesIO(result) as bio:
return torch.load(bio, map_location='cuda')
def send_generate_end_flag(self):
data = {
'type': 'generate_end'
}
self.socket_pub.send(self.topic + self.wrap_result(data))
def generate_thread(self, *args, **kwargs):
clear_torch_cache()
self.is_generating = True
try:
self.model.generate(*args, **kwargs)
except ValueError:
pass
finally:
self.is_generating = False
self.send_generate_end_flag()
clear_torch_cache()
def stop_generate(self):
self.is_generating = False
def run(self):
self.load_model()
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind("tcp://*:{}".format(self.port))
self.socket = socket
context_pub = zmq.Context()
socket_pub = context_pub.socket(zmq.PUB)
socket_pub.bind("tcp://*:{}".format(self.pub_port))
self.socket_pub = socket_pub
print('Server started at port {} and {}.'.format(self.port, self.pub_port))
'''
Message Format:
{'function': 'generate',
'args': ...,
'kwargs': ...}
'''
while True:
try:
# Wait for next request from client
message = socket.recv()
message = self.unwrap_result(message)
function = message['function']
if function == 'generate':
if not self.is_generating:
self.is_generating = True
args = message['args']
kwargs = message['kwargs']
input_ids = kwargs['inputs']
def func(x):
if not self.is_generating:
raise ValueError
new_tokens = len(x) - len(input_ids[0])
result = decode(x[-new_tokens:], self.tokenizer, True)
data = {
'type': 'generate',
'data': result
}
socket_pub.send(self.topic + self.wrap_result(data))
kwargs['stopping_criteria'] = StoppingCriteriaList([Stream(callback_func=func)])
t = threading.Thread(target=self.generate_thread, args=args, kwargs=kwargs)
t.setDaemon(True)
t.start()
socket.send(self.wrap_result({'type': 'generate_rsp', 'data': 'ok'}))
else:
print('Already generating.')
socket.send(self.wrap_result({'type': 'generate_rsp', 'data': 'already generating'}))
elif function == 'stop_generate':
self.stop_generate()
socket.send(self.wrap_result({'type': 'stop_generate_rsp', 'data': 'ok'}))
elif function == 'test':
print('test ok.')
self.socket.send(self.wrap_result(
{
'type': 'test',
'data': 'test ok.'
}
))
elif function == 'exit':
socket.send(self.wrap_result({'type': 'exit_rsp', 'data': 'ok'}))
break
else:
socket.send(self.wrap_result({'type': 'rsp', 'data': 'no function'}))
raise ValueError('Unknown function {}'.format(function))
except Exception as e:
print(str(e))
raise
print('Server stopped.')
class ModelClient:
def __init__(self, port=5555, port_sub=5556):
self.port = port
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect("tcp://localhost:{}".format(self.port))
self.socket_sub = self.context.socket(zmq.SUB)
self.topic = b'10001'
self.socket_sub.setsockopt(zmq.SUBSCRIBE, self.topic)
self.socket_sub.connect("tcp://localhost:{}".format(port_sub))
self.callback_func = None
def wrap_result(self, result):
with BytesIO() as bio:
torch.save(result, bio)
return bio.getvalue()
def unwrap_result(self, result):
with BytesIO(result) as bio:
return torch.load(bio, map_location='cuda')
def recieve_thread(self):
while True:
message = self.socket_sub.recv()
message = message[len(self.topic):]
message = self.unwrap_result(message)
if message['type'] == 'generate':
if self.callback_func is not None:
self.callback_func(message['data'], is_end=False)
elif message['type'] == 'generate_end':
if self.callback_func is not None:
self.callback_func(None, is_end=True)
break
else:
print(message)
break
print('receive completed.')
def start_recieving(self):
t = threading.Thread(target=self.recieve_thread)
t.setDaemon(True)
t.start()
def generate(self, *args, **kwargs):
data = {
'function': 'generate',
'args': args,
'kwargs': kwargs
}
self.socket.send(self.wrap_result(data))
result = self.socket.recv()
return result
def stop(self):
data = {
'function': 'stop_generate'
}
self.socket.send(self.wrap_result(data))
result = self.socket.recv()
return result
def test(self):
data = {
'function': 'test'
}
self.socket.send(self.wrap_result(data))
result = self.socket.recv()
return result

View File

@ -0,0 +1,37 @@
from server import ModelClient
from transformers import LlamaTokenizer
def load_model_llama(*args, **kwargs):
config_path = '../llama-13b-4bit/'
tokenizer = LlamaTokenizer.from_pretrained(config_path)
tokenizer.truncation_side = 'left'
model = ModelClient(port=5555, port_sub=5556)
return model, tokenizer
patch_encode_func = True
# Monkey Patch
from modules import models
from modules import shared
models.load_model = load_model_llama
shared.args.model = 'llama-13b-4bit'
shared.settings['name1'] = 'You'
shared.settings['name2'] = 'Assistant'
shared.settings['chat_prompt_size_max'] = 2048
shared.settings['chat_prompt_size'] = 2048
if patch_encode_func:
from modules import text_generation
text_generation.encode_old = text_generation.encode
def encode_patched(*args, **kwargs):
input_ids = text_generation.encode_old(*args, **kwargs)
if input_ids[0,0] == 0:
input_ids = input_ids[:, 1:]
return input_ids
text_generation.encode = encode_patched
print('Encode Function Patched.')
print('Monkey Patch Completed.')
# Apply Generate Monkey Patch
import generate_monkey_patch

View File

@ -0,0 +1,213 @@
import modules.text_generation
from modules.text_generation import *
from modules.callbacks import _SentinelTokenStoppingCriteria
def generate_reply_patched(question, state, eos_token=None, stopping_strings=[]):
if shared.model_name == 'None' or shared.model is None:
print("No model is loaded! Select one in the Model tab.")
yield formatted_outputs(question, shared.model_name)
return
clear_torch_cache()
seed = set_manual_seed(state['seed'])
shared.stop_everything = False
generate_params = get_generate_params(state)
t0 = time.time()
# Preparing the input
original_question = question
if not shared.is_chat():
question = apply_extensions('input', question)
# If the model is not on transformers, handle it separately and end this
# function call earlier.
if shared.model_type in ['rwkv', 'llamacpp']:
if shared.args.verbose:
print(f'\n\n{question}\n--------------------\n')
try:
if shared.args.no_stream:
reply = shared.model.generate(context=question, **generate_params)
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
yield formatted_outputs(reply, shared.model_name)
else:
if not shared.is_chat():
yield formatted_outputs(question, shared.model_name)
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
yield formatted_outputs(reply, shared.model_name)
except Exception:
traceback.print_exc()
finally:
t1 = time.time()
original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(output)[0]) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return
# Encode the input
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
output = input_ids[0]
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
if shared.args.verbose:
print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n')
# Find the eos tokens
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
if eos_token is not None:
eos_token_ids.append(int(encode(eos_token)[0][-1]))
# Create the StoppingCriteriaList with the stopping strings
stopping_criteria_list = transformers.StoppingCriteriaList()
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
if type(st) is list and len(st) > 0:
sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
break
# Update generate_params with the eos token and the stopping strings
if shared.args.flexgen:
generate_params['stop'] = eos_token_ids[-1]
else:
generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = stopping_criteria_list
# Add the encoded tokens to generate_params
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
original_input_ids = input_ids
generate_params.update({'inputs_embeds': inputs_embeds})
generate_params.update({'inputs': filler_input_ids})
else:
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
original_input_ids = input_ids
generate_params.update({'inputs': input_ids})
if inputs_embeds is not None:
generate_params.update({'inputs_embeds': inputs_embeds})
try:
# Generate the entire reply at once.
if shared.args.no_stream:
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
if cuda:
output = output.cuda()
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(input_ids[0])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
yield formatted_outputs(reply, shared.model_name)
# Stream the reply 1 token at a time.
# This is based on the trick of using 'stopping_criteria' to create an iterator.
elif not shared.args.flexgen:
# def generate_with_callback(callback=None, **kwargs):
# kwargs['stopping_criteria'].append(Stream(callback_func=callback))
# clear_torch_cache()
# with torch.no_grad():
# shared.model.generate(**kwargs)
# def generate_with_streaming(**kwargs):
# return Iteratorize(generate_with_callback, kwargs, callback=None)
# if not shared.is_chat():
# yield formatted_outputs(original_question, shared.model_name)
# with generate_with_streaming(**generate_params) as generator:
# for output in generator:
# if shared.soft_prompt:
# output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
# new_tokens = len(output) - len(input_ids[0])
# reply = decode(output[-new_tokens:], state['skip_special_tokens'])
# if not shared.is_chat():
# reply = original_question + apply_extensions('output', reply)
# if output[-1] in eos_token_ids:
# break
# yield formatted_outputs(reply, shared.model_name)
from queue import Queue
queue = Queue()
def callback_func(x, is_end=False):
if not is_end:
queue.put(x)
else:
queue.put(None)
# remove stopping_criteria
generate_params.pop('stopping_criteria')
shared.model.callback_func = callback_func
shared.model.generate(**generate_params)
shared.model.start_recieving()
token_count = 0
while True:
reply = queue.get()
if reply is None:
break
token_count += 1
yield formatted_outputs(reply, shared.model_name)
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
for i in range(state['max_new_tokens'] // 8 + 1):
clear_torch_cache()
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(original_input_ids[0])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break
yield formatted_outputs(reply, shared.model_name)
input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
generate_params.update({'inputs_embeds': inputs_embeds})
generate_params.update({'inputs': filler_input_ids})
else:
generate_params.update({'inputs': input_ids})
yield formatted_outputs(reply, shared.model_name)
except Exception:
traceback.print_exc()
finally:
t1 = time.time()
try:
shared.model.stop()
except:
pass
original_tokens = len(original_input_ids[0])
new_tokens = token_count
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return
modules.text_generation.generate_reply_old = modules.text_generation.generate_reply
modules.text_generation.generate_reply = generate_reply_patched
print('Generate Patch Applied')