diff --git a/requirements.txt b/requirements.txt
index e7b9ed1..c574159 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,7 @@ sentencepiece
 safetensors
 einops
 colorama
+pyzmq
 git+https://github.com/huggingface/peft.git@70af02a2bca5a63921790036b2c9430edf4037e2
 git+https://github.com/huggingface/transformers.git
 git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
diff --git a/scripts/run_server.py b/scripts/run_server.py
new file mode 100644
index 0000000..602a631
--- /dev/null
+++ b/scripts/run_server.py
@@ -0,0 +1,26 @@
+from server import ModelServer
+import argparse
+
+if __name__ == '__main__':
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument('--config_path', type=str, required=True)
+    arg_parser.add_argument('--model_path', type=str, required=True)
+    arg_parser.add_argument('--lora_path', type=str, default=None)
+    arg_parser.add_argument('--groupsize', type=int, default=-1)
+    arg_parser.add_argument('--v1', action='store_true')
+    arg_parser.add_argument('--quant_attn', action='store_true')
+    arg_parser.add_argument('--port', type=int, default=5555)
+    arg_parser.add_argument('--pub_port', type=int, default=5556)
+    args = arg_parser.parse_args()
+
+    server = ModelServer(
+        config_path=args.config_path,
+        model_path=args.model_path,
+        lora_path=args.lora_path,
+        groupsize=args.groupsize,
+        is_v1_model=args.v1,
+        quant_attn=args.quant_attn,
+        port=args.port,
+        pub_port=args.pub_port)
+
+    server.run()
diff --git a/server/__init__.py b/server/__init__.py
new file mode 100644
index 0000000..d01cea1
--- /dev/null
+++ b/server/__init__.py
@@ -0,0 +1 @@
+from .server import ModelClient, ModelServer
diff --git a/server/server.py b/server/server.py
new file mode 100644
index 0000000..ef3dae8
--- /dev/null
+++ b/server/server.py
@@ -0,0 +1,264 @@
+from .. import autograd_4bit
+import time
+import torch
+from ..autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
+from alpaca_lora_4bit.model_attn_mlp_patch import make_quant_attn, make_fused_mlp, inject_lora_layers
+import zmq
+from transformers import StoppingCriteria, StoppingCriteriaList
+from io import BytesIO
+import gc
+import threading
+
+
+def decode(output_ids, tokenizer, skip_special_tokens=True):
+    if skip_special_tokens:
+        reply = tokenizer.decode(output_ids, skip_special_tokens=True)
+        reply = reply.replace(r'<|endoftext|>', '')
+        return reply
+    else:
+        return tokenizer.decode(output_ids, skip_special_tokens=False)
+
+
+def clear_torch_cache():
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+# Copy from text-generation-webui/modules/callbacks.py
+class Stream(StoppingCriteria):
+    def __init__(self, callback_func=None):
+        self.callback_func = callback_func
+
+    def __call__(self, input_ids, scores) -> bool:
+        if self.callback_func is not None:
+            self.callback_func(input_ids[0])
+        return False
+
+
+class ModelServer:
+
+    def __init__(self, config_path, model_path, lora_path=None, groupsize=128, is_v1_model=False, quant_attn=False, port=5555, pub_port=5556):
+        self.config_path = config_path
+        self.model_path = model_path
+        self.lora_path = lora_path
+        self.groupsize = groupsize
+        self.is_v1_model = is_v1_model
+        self.quant_attn = quant_attn
+        self.port = port
+        self.model = None
+        self.tokenizer = None
+        self.is_generating = False
+        self.socket = None
+        self.socket_pub = None
+        self.pub_port = pub_port
+        self.topic = b'10001'
+
+    def load_model(self):
+        print("Loading {} ...".format(self.model_path))
+        t0 = time.time()
+        model, tokenizer = load_llama_model_4bit_low_ram(self.config_path, self.model_path, groupsize=self.groupsize, is_v1_model=self.is_v1_model)
+
+        if not self.quant_attn and self.lora_path is not None:
+            from peft import PeftModel
+            from ..monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
+            replace_peft_model_with_int4_lora_model()
+            model = PeftModel.from_pretrained(model, self.lora_path, device_map={'': 0}, torch_dtype=torch.float16)
+            print('{} Lora Applied.'.format(self.lora_path))
+
+        print('Apply half ...')
+        model.half()
+        for n, m in model.named_modules():
+            if isinstance(m, Autograd4bitQuantLinear):
+                if m.is_v1_model:
+                    m.zeros = m.zeros.half()
+                m.scales = m.scales.half()
+                m.bias = m.bias.half()
+        torch.cuda.empty_cache()
+        print('Total {:.2f} GiB VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024))
+
+        if not self.quant_attn and self.lora_path is not None:
+            from ..amp_wrapper import AMPWrapper
+            wrapper = AMPWrapper(model)
+            wrapper.apply_generate()
+            print('AMP applied.')
+
+        if self.quant_attn:
+            make_quant_attn(model, is_v1_model=self.is_v1_model)
+            make_fused_mlp(model, is_v1_model=self.is_v1_model)
+            print('Quantized attention applied.')
+
+            if self.lora_path is not None:
+                inject_lora_layers(model, self.lora_path, device='cuda', torch_dtype=torch.float16)
+
+        self.model, self.tokenizer = model, tokenizer
+        print("Loaded in {:.2f} seconds.".format(time.time() - t0))
+
+    def wrap_result(self, result):
+        with BytesIO() as bio:
+            torch.save(result, bio)
+            return bio.getvalue()
+
+    def unwrap_result(self, result):
+        with BytesIO(result) as bio:
+            return torch.load(bio, map_location='cuda')
+
+    def send_generate_end_flag(self):
+        data = {
+            'type': 'generate_end'
+        }
+        self.socket_pub.send(self.topic + self.wrap_result(data))
+
+    def generate_thread(self, *args, **kwargs):
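+        # Generation runs on this worker thread; partial results are pushed to the
+        # client over the PUB socket by the Stream callback installed in run().
+        # stop_generate() clears is_generating, which makes that callback raise
+        # ValueError and abort generation early (hence the `except ValueError: pass`).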
+        clear_torch_cache()
+        self.is_generating = True
+        try:
+            self.model.generate(*args, **kwargs)
+        except ValueError:
+            pass
+        finally:
+            self.is_generating = False
+            self.send_generate_end_flag()
+            clear_torch_cache()
+
+    def stop_generate(self):
+        self.is_generating = False
+
+    def run(self):
+        self.load_model()
+        context = zmq.Context()
+        socket = context.socket(zmq.REP)
+        socket.bind("tcp://*:{}".format(self.port))
+        self.socket = socket
+        context_pub = zmq.Context()
+        socket_pub = context_pub.socket(zmq.PUB)
+        socket_pub.bind("tcp://*:{}".format(self.pub_port))
+        self.socket_pub = socket_pub
+        print('Server started at port {} and {}.'.format(self.port, self.pub_port))
+        '''
+        Message Format:
+        {'function': 'generate',
+         'args': ...,
+         'kwargs': ...}
+        '''
+        while True:
+            try:
+                # Wait for next request from client
+                message = socket.recv()
+                message = self.unwrap_result(message)
+                function = message['function']
+                if function == 'generate':
+                    if not self.is_generating:
+                        self.is_generating = True
+                        args = message['args']
+                        kwargs = message['kwargs']
+                        input_ids = kwargs['inputs']
+                        def func(x):
+                            if not self.is_generating:
+                                raise ValueError
+                            new_tokens = len(x) - len(input_ids[0])
+                            result = decode(x[-new_tokens:], self.tokenizer, True)
+                            data = {
+                                'type': 'generate',
+                                'data': result
+                            }
+                            socket_pub.send(self.topic + self.wrap_result(data))
+                        kwargs['stopping_criteria'] = StoppingCriteriaList([Stream(callback_func=func)])
+                        t = threading.Thread(target=self.generate_thread, args=args, kwargs=kwargs)
+                        t.setDaemon(True)
+                        t.start()
+                        socket.send(self.wrap_result({'type': 'generate_rsp', 'data': 'ok'}))
+                    else:
+                        print('Already generating.')
+                        socket.send(self.wrap_result({'type': 'generate_rsp', 'data': 'already generating'}))
+                elif function == 'stop_generate':
+                    self.stop_generate()
+                    socket.send(self.wrap_result({'type': 'stop_generate_rsp', 'data': 'ok'}))
+                elif function == 'test':
+                    print('test ok.')
+                    self.socket.send(self.wrap_result(
+                        {
+                            'type': 'test',
+                            'data': 'test ok.'
+                        }
+                    ))
+                elif function == 'exit':
+                    socket.send(self.wrap_result({'type': 'exit_rsp', 'data': 'ok'}))
+                    break
+                else:
+                    socket.send(self.wrap_result({'type': 'rsp', 'data': 'no function'}))
+                    raise ValueError('Unknown function {}'.format(function))
+            except Exception as e:
+                print(str(e))
+                raise
+        print('Server stopped.')
+
+
+class ModelClient:
+
+    def __init__(self, port=5555, port_sub=5556):
+        self.port = port
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.REQ)
+        self.socket.connect("tcp://localhost:{}".format(self.port))
+        self.socket_sub = self.context.socket(zmq.SUB)
+        self.topic = b'10001'
+        self.socket_sub.setsockopt(zmq.SUBSCRIBE, self.topic)
+        self.socket_sub.connect("tcp://localhost:{}".format(port_sub))
+        self.callback_func = None
+
+    def wrap_result(self, result):
+        with BytesIO() as bio:
+            torch.save(result, bio)
+            return bio.getvalue()
+
+    def unwrap_result(self, result):
+        with BytesIO(result) as bio:
+            return torch.load(bio, map_location='cuda')
+
+    def recieve_thread(self):
+        while True:
+            message = self.socket_sub.recv()
+            message = message[len(self.topic):]
+            message = self.unwrap_result(message)
+            if message['type'] == 'generate':
+                if self.callback_func is not None:
+                    self.callback_func(message['data'], is_end=False)
+            elif message['type'] == 'generate_end':
+                if self.callback_func is not None:
+                    self.callback_func(None, is_end=True)
+                break
+            else:
+                print(message)
+                break
+        print('receive completed.')
+
+    def start_recieving(self):
+        t = threading.Thread(target=self.recieve_thread)
+        t.setDaemon(True)
+        t.start()
+
+    def generate(self, *args, **kwargs):
+        data = {
+            'function': 'generate',
+            'args': args,
+            'kwargs': kwargs
+        }
+        self.socket.send(self.wrap_result(data))
+        result = self.socket.recv()
+        return result
+
+    def stop(self):
+        data = {
+            'function': 'stop_generate'
+        }
+        self.socket.send(self.wrap_result(data))
+        result = self.socket.recv()
+        return result
+
+    def test(self):
+        data = {
+            'function': 'test'
+        }
+        self.socket.send(self.wrap_result(data))
+        result = self.socket.recv()
+        return result
diff --git a/text-generation-webui/custom_model_server_monkey_patch.py b/text-generation-webui/custom_model_server_monkey_patch.py
new file mode 100644
index 0000000..ea76f8f
--- /dev/null
+++ b/text-generation-webui/custom_model_server_monkey_patch.py
@@ -0,0 +1,37 @@
+from server import ModelClient
+from transformers import LlamaTokenizer
+
+def load_model_llama(*args, **kwargs):
+    config_path = '../llama-13b-4bit/'
+    tokenizer = LlamaTokenizer.from_pretrained(config_path)
+    tokenizer.truncation_side = 'left'
+    model = ModelClient(port=5555, port_sub=5556)
+    return model, tokenizer
+
+patch_encode_func = True
+
+# Monkey Patch
+from modules import models
+from modules import shared
+models.load_model = load_model_llama
+shared.args.model = 'llama-13b-4bit'
+shared.settings['name1'] = 'You'
+shared.settings['name2'] = 'Assistant'
+shared.settings['chat_prompt_size_max'] = 2048
+shared.settings['chat_prompt_size'] = 2048
+
+if patch_encode_func:
+    from modules import text_generation
+    text_generation.encode_old = text_generation.encode
+    def encode_patched(*args, **kwargs):
+        input_ids = text_generation.encode_old(*args, **kwargs)
+        if input_ids[0,0] == 0:
+            input_ids = input_ids[:, 1:]
+        return input_ids
+    text_generation.encode = encode_patched
+    print('Encode Function Patched.')
+
+print('Monkey Patch Completed.')
+
+# Apply Generate Monkey Patch
+import generate_monkey_patch
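For reference, the client API added in `server/server.py` can also be driven outside the webui. The sketch below is a minimal, hypothetical example and is not part of this diff: it assumes a server already started via `scripts/run_server.py` on the default ports, a local tokenizer directory (path hypothetical), and it mirrors the callback wiring used by `generate_monkey_patch.py` below. Note that each `'generate'` message carries the cumulative decoded reply so far, not a single-token delta.

```python
# Minimal client-side sketch (not part of this diff). Assumes the model server is
# already running, e.g.:
#   python scripts/run_server.py --config_path ../llama-13b-4bit/ --model_path <checkpoint>
# (paths are hypothetical).
import threading
from transformers import LlamaTokenizer
from server import ModelClient

tokenizer = LlamaTokenizer.from_pretrained('../llama-13b-4bit/')  # hypothetical path
client = ModelClient(port=5555, port_sub=5556)

done = threading.Event()
latest = {'reply': ''}

def on_update(text, is_end=False):
    # The server publishes the cumulative decoded reply after every new token;
    # is_end=True (with text=None) marks the end of generation.
    if is_end:
        done.set()
    else:
        latest['reply'] = text

client.callback_func = on_update
input_ids = tokenizer('Hello, how are you?', return_tensors='pt').input_ids
client.generate(inputs=input_ids, max_new_tokens=64, do_sample=True, temperature=0.7)
client.start_recieving()  # spelling matches the method defined in server/server.py
done.wait()
print(latest['reply'])
```

The split between the REQ/REP socket (commands and acknowledgements) and the PUB/SUB socket (streamed partial replies) is what lets the webui stay responsive while the 4-bit model generates in a separate process.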
diff --git a/text-generation-webui/generate_monkey_patch.py b/text-generation-webui/generate_monkey_patch.py
new file mode 100644
index 0000000..efe52fa
--- /dev/null
+++ b/text-generation-webui/generate_monkey_patch.py
@@ -0,0 +1,213 @@
+import modules.text_generation
+from modules.text_generation import *
+from modules.callbacks import _SentinelTokenStoppingCriteria
+
+def generate_reply_patched(question, state, eos_token=None, stopping_strings=[]):
+    if shared.model_name == 'None' or shared.model is None:
+        print("No model is loaded! Select one in the Model tab.")
+        yield formatted_outputs(question, shared.model_name)
+        return
+
+    clear_torch_cache()
+    seed = set_manual_seed(state['seed'])
+    shared.stop_everything = False
+    generate_params = get_generate_params(state)
+    t0 = time.time()
+
+    # Preparing the input
+    original_question = question
+    if not shared.is_chat():
+        question = apply_extensions('input', question)
+
+    # If the model is not on transformers, handle it separately and end this
+    # function call earlier.
+    if shared.model_type in ['rwkv', 'llamacpp']:
+        if shared.args.verbose:
+            print(f'\n\n{question}\n--------------------\n')
+
+        try:
+            if shared.args.no_stream:
+                reply = shared.model.generate(context=question, **generate_params)
+                output = original_question + reply
+                if not shared.is_chat():
+                    reply = original_question + apply_extensions('output', reply)
+
+                yield formatted_outputs(reply, shared.model_name)
+            else:
+                if not shared.is_chat():
+                    yield formatted_outputs(question, shared.model_name)
+
+                for reply in shared.model.generate_with_streaming(context=question, **generate_params):
+                    output = original_question + reply
+                    if not shared.is_chat():
+                        reply = original_question + apply_extensions('output', reply)
+
+                    yield formatted_outputs(reply, shared.model_name)
+
+        except Exception:
+            traceback.print_exc()
+        finally:
+            t1 = time.time()
+            original_tokens = len(encode(original_question)[0])
+            new_tokens = len(encode(output)[0]) - original_tokens
+            print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
+            return
+
+    # Encode the input
+    input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
+    output = input_ids[0]
+    cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
+    if shared.args.verbose:
+        print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n')
+
+    # Find the eos tokens
+    eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
+    if eos_token is not None:
+        eos_token_ids.append(int(encode(eos_token)[0][-1]))
+
+    # Create the StoppingCriteriaList with the stopping strings
+    stopping_criteria_list = transformers.StoppingCriteriaList()
+    for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
+        if type(st) is list and len(st) > 0:
+            sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
+            stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
+            break
+
+    # Update generate_params with the eos token and the stopping strings
+    if shared.args.flexgen:
+        generate_params['stop'] = eos_token_ids[-1]
+    else:
+        generate_params['eos_token_id'] = eos_token_ids
+        generate_params['stopping_criteria'] = stopping_criteria_list
+
+    # Add the encoded tokens to generate_params
+    if shared.soft_prompt:
+        inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+        question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
+        original_input_ids = input_ids
+        generate_params.update({'inputs_embeds': inputs_embeds})
+        generate_params.update({'inputs': filler_input_ids})
+    else:
+        question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
+        original_input_ids = input_ids
+        generate_params.update({'inputs': input_ids})
+        if inputs_embeds is not None:
+            generate_params.update({'inputs_embeds': inputs_embeds})
+
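+    # Unlike upstream text-generation-webui, the streaming branch below does not drive
+    # generation locally through Iteratorize/Stream (kept as the commented-out block for
+    # reference). Instead it registers a callback on the remote ModelClient, sends the
+    # generate request to the model server, and yields partial replies from a local
+    # queue as they arrive over the PUB socket.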
+    # Initialize the streamed-token counter up front so the timing summary in the
+    # finally block also works when the streaming branch is not taken.
+    token_count = 0
+    try:
+        # Generate the entire reply at once.
+        if shared.args.no_stream:
+            with torch.no_grad():
+                output = shared.model.generate(**generate_params)[0]
+                if cuda:
+                    output = output.cuda()
+
+            if shared.soft_prompt:
+                output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+            new_tokens = len(output) - len(input_ids[0])
+            token_count = new_tokens
+            reply = decode(output[-new_tokens:], state['skip_special_tokens'])
+            if not shared.is_chat():
+                reply = original_question + apply_extensions('output', reply)
+
+            yield formatted_outputs(reply, shared.model_name)
+
+        # Stream the reply 1 token at a time.
+        # This is based on the trick of using 'stopping_criteria' to create an iterator.
+        elif not shared.args.flexgen:
+
+            # def generate_with_callback(callback=None, **kwargs):
+            #     kwargs['stopping_criteria'].append(Stream(callback_func=callback))
+            #     clear_torch_cache()
+            #     with torch.no_grad():
+            #         shared.model.generate(**kwargs)
+
+            # def generate_with_streaming(**kwargs):
+            #     return Iteratorize(generate_with_callback, kwargs, callback=None)
+
+            # if not shared.is_chat():
+            #     yield formatted_outputs(original_question, shared.model_name)
+
+            # with generate_with_streaming(**generate_params) as generator:
+            #     for output in generator:
+            #         if shared.soft_prompt:
+            #             output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+            #         new_tokens = len(output) - len(input_ids[0])
+            #         reply = decode(output[-new_tokens:], state['skip_special_tokens'])
+            #         if not shared.is_chat():
+            #             reply = original_question + apply_extensions('output', reply)
+
+            #         if output[-1] in eos_token_ids:
+            #             break
+
+            #         yield formatted_outputs(reply, shared.model_name)
+
+            from queue import Queue
+            queue = Queue()
+            def callback_func(x, is_end=False):
+                if not is_end:
+                    queue.put(x)
+                else:
+                    queue.put(None)
+
+            # remove stopping_criteria
+            generate_params.pop('stopping_criteria')
+
+            shared.model.callback_func = callback_func
+            shared.model.generate(**generate_params)
+            shared.model.start_recieving()
+
+            token_count = 0
+            while True:
+                reply = queue.get()
+                if reply is None:
+                    break
+                token_count += 1
+                yield formatted_outputs(reply, shared.model_name)
+
+        # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
+        else:
+            for i in range(state['max_new_tokens'] // 8 + 1):
+                clear_torch_cache()
+                with torch.no_grad():
+                    output = shared.model.generate(**generate_params)[0]
+
+                if shared.soft_prompt:
+                    output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+                new_tokens = len(output) - len(original_input_ids[0])
+                reply = decode(output[-new_tokens:], state['skip_special_tokens'])
+                if not shared.is_chat():
+                    reply = original_question + apply_extensions('output', reply)
+
+                if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
+                    break
+
+                yield formatted_outputs(reply, shared.model_name)
+                input_ids = np.reshape(output, (1, output.shape[0]))
+                if shared.soft_prompt:
+                    inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+                    generate_params.update({'inputs_embeds': inputs_embeds})
+                    generate_params.update({'inputs': filler_input_ids})
+                else:
+                    generate_params.update({'inputs': input_ids})
+
+            yield formatted_outputs(reply, shared.model_name)
+
+    except Exception:
+        traceback.print_exc()
+    finally:
+        t1 = time.time()
+        try:
+            shared.model.stop()
+        except:
+            pass
+        original_tokens = len(original_input_ids[0])
+        new_tokens = token_count
+        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
+        return
+
+modules.text_generation.generate_reply_old = modules.text_generation.generate_reply
+modules.text_generation.generate_reply = generate_reply_patched
+print('Generate Patch Applied')
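As a closing reference, the PUB channel's wire format is implicit in `wrap_result` and `recieve_thread`: each frame is the topic prefix `b'10001'` followed by a `torch.save`-serialized dict with a `'type'` field (`'generate'` carries the decoded reply so far, `'generate_end'` marks completion). A minimal raw subscriber sketch, not part of this diff and assuming the default publish port:

```python
# Hypothetical raw subscriber; it duplicates what ModelClient.recieve_thread already does.
from io import BytesIO

import torch
import zmq

TOPIC = b'10001'  # topic prefix used by ModelServer

context = zmq.Context()
sub = context.socket(zmq.SUB)
sub.setsockopt(zmq.SUBSCRIBE, TOPIC)
sub.connect('tcp://localhost:5556')  # default --pub_port

while True:
    frame = sub.recv()
    # Payloads here are plain dicts of strings, so loading onto the CPU is sufficient.
    payload = torch.load(BytesIO(frame[len(TOPIC):]), map_location='cpu')
    if payload['type'] == 'generate':
        print(payload['data'])  # cumulative reply decoded by the server
    elif payload['type'] == 'generate_end':
        break
```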