import time
import torch
from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
from model_attn_mlp_patch import make_quant_attn, make_fused_mlp, inject_lora_layers
import zmq
from transformers import StoppingCriteria, StoppingCriteriaList
from io import BytesIO
import gc
import threading


def decode(output_ids, tokenizer, skip_special_tokens=True):
    if skip_special_tokens:
        reply = tokenizer.decode(output_ids, skip_special_tokens=True)
        reply = reply.replace(r'<|endoftext|>', '')
        return reply
    else:
        return tokenizer.decode(output_ids, skip_special_tokens=False)


def clear_torch_cache():
    gc.collect()
    torch.cuda.empty_cache()


# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(StoppingCriteria):

    def __init__(self, sentinel_token_ids: list, starting_idx: int):
        StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]

            for i in range(len(self.sentinel_token_ids)):
                # Can't unfold, output is still too tiny. Skip.
                if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
                    continue

                for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
                    if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
                        return True
        return False


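# Note (illustrative, not from the original file): sentinel_token_ids is typically a
# list of 2-D token-id tensors, one per stop string, e.g.
# tokenizer(stop, add_special_tokens=False, return_tensors='pt').input_ids.to(device),
# with starting_idx set to the prompt length so only newly generated tokens are scanned.

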
# Copied from text-generation-webui/modules/callbacks.py
class Stream(StoppingCriteria):

    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False


class ModelServer:
    '''Serves a 4-bit LLaMA model over ZeroMQ: a REP socket on ``port`` answers
    control requests (generate / stop_generate / test / exit), and a PUB socket
    on ``pub_port`` streams decoded tokens to subscribers as they are generated.'''

    def __init__(self, config_path, model_path, lora_path=None, groupsize=128,
                 is_v1_model=False, quant_attn=False, port=5555, pub_port=5556):
        self.config_path = config_path
        self.model_path = model_path
        self.lora_path = lora_path
        self.groupsize = groupsize
        self.is_v1_model = is_v1_model
        self.quant_attn = quant_attn
        self.port = port
        self.model = None
        self.tokenizer = None
        self.is_generating = False
        self.socket = None
        self.socket_pub = None
        self.pub_port = pub_port
        self.topic = b'10001'

    def load_model(self):
        print("Loading {} ...".format(self.model_path))
        t0 = time.time()
        model, tokenizer = load_llama_model_4bit_low_ram(self.config_path, self.model_path,
                                                         groupsize=self.groupsize,
                                                         is_v1_model=self.is_v1_model)

        if not self.quant_attn and self.lora_path is not None:
            from peft import PeftModel
            from ..monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
            replace_peft_model_with_int4_lora_model()
            model = PeftModel.from_pretrained(model, self.lora_path, device_map={'': 0}, torch_dtype=torch.float16)
            print('{} Lora Applied.'.format(self.lora_path))

        print('Apply half ...')
        model.half()
        for n, m in model.named_modules():
            if isinstance(m, Autograd4bitQuantLinear):
                # Only v1-format layers carry a float `zeros` tensor to cast;
                # scales and bias are cast for both formats.
                if m.is_v1_model:
                    m.zeros = m.zeros.half()
                m.scales = m.scales.half()
                m.bias = m.bias.half()
        torch.cuda.empty_cache()
        # memory_allocated() returns bytes, so divide by 1024**3 for GiB.
        print('Total {:.2f} GiB VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024))

        if not self.quant_attn and self.lora_path is not None:
            from ..amp_wrapper import AMPWrapper
            wrapper = AMPWrapper(model)
            wrapper.apply_generate()
            print('AMP applied.')

        if self.quant_attn:
            make_quant_attn(model, is_v1_model=self.is_v1_model)
            make_fused_mlp(model, is_v1_model=self.is_v1_model)
            print('Quantized attention applied.')
            if self.lora_path is not None:
                inject_lora_layers(model, self.lora_path, device='cuda', dtype=torch.float16)

        self.model, self.tokenizer = model, tokenizer
        print("Loaded in {:.2f} seconds.".format(time.time() - t0))

    def wrap_result(self, result):
        # Serialize an arbitrary Python object (tensors included) to bytes via torch.save.
        with BytesIO() as bio:
            torch.save(result, bio)
            return bio.getvalue()

    def unwrap_result(self, result):
        # Deserialize bytes produced by wrap_result, mapping tensors onto the GPU.
        with BytesIO(result) as bio:
            return torch.load(bio, map_location='cuda')

    def send_generate_end_flag(self):
        data = {
            'type': 'generate_end'
        }
        self.socket_pub.send(self.topic + self.wrap_result(data))

    def generate_thread(self, *args, **kwargs):
        clear_torch_cache()
        self.is_generating = True
        try:
            self.model.generate(*args, **kwargs)
        except ValueError:
            # Raised by the streaming callback in run() once stop_generate() is requested.
            pass
        finally:
            self.is_generating = False
            self.send_generate_end_flag()
            clear_torch_cache()

    def stop_generate(self):
        self.is_generating = False

    def run(self):
        self.load_model()
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.bind("tcp://*:{}".format(self.port))
        self.socket = socket
        context_pub = zmq.Context()
        socket_pub = context_pub.socket(zmq.PUB)
        socket_pub.bind("tcp://*:{}".format(self.pub_port))
        self.socket_pub = socket_pub
        print('Server started on ports {} and {}.'.format(self.port, self.pub_port))
        '''
        Message Format:
        {'function': 'generate',
         'args': ...,
         'kwargs': ...}
        '''
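        # Illustrative request (not part of the original module): a client would send,
        # serialized with torch.save, something like
        #     {'function': 'generate',
        #      'args': (),
        #      'kwargs': {'inputs': input_ids, 'max_new_tokens': 200, 'do_sample': True}}
        # where 'inputs' holds the prompt token ids and all kwargs are forwarded to
        # model.generate().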
        while True:
            try:
                # Wait for the next request from a client.
                message = socket.recv()
                message = self.unwrap_result(message)
                function = message['function']
                if function == 'generate':
                    if not self.is_generating:
                        self.is_generating = True
                        args = message['args']
                        kwargs = message['kwargs']
                        input_ids = kwargs['inputs']

                        def func(x):
                            # Streaming callback: publish the decoded completion so far on
                            # every step, or abort (via ValueError) once stop_generate()
                            # has been called.
                            if not self.is_generating:
                                raise ValueError
                            new_tokens = len(x) - len(input_ids[0])
                            result = decode(x[-new_tokens:], self.tokenizer, True)
                            data = {
                                'type': 'generate',
                                'data': result
                            }
                            socket_pub.send(self.topic + self.wrap_result(data))

                        kwargs['stopping_criteria'] = StoppingCriteriaList([Stream(callback_func=func)])
                        t = threading.Thread(target=self.generate_thread, args=args, kwargs=kwargs)
                        t.daemon = True
                        t.start()
                        socket.send(self.wrap_result({'type': 'generate_rsp', 'data': 'ok'}))
                    else:
                        print('Already generating.')
                        socket.send(self.wrap_result({'type': 'generate_rsp', 'data': 'already generating'}))
                elif function == 'stop_generate':
                    self.stop_generate()
                    socket.send(self.wrap_result({'type': 'stop_generate_rsp', 'data': 'ok'}))
                elif function == 'test':
                    print('test ok.')
                    self.socket.send(self.wrap_result(
                        {
                            'type': 'test',
                            'data': 'test ok.'
                        }
                    ))
                elif function == 'exit':
                    socket.send(self.wrap_result({'type': 'exit_rsp', 'data': 'ok'}))
                    break
                else:
                    socket.send(self.wrap_result({'type': 'rsp', 'data': 'no function'}))
                    raise ValueError('Unknown function {}'.format(function))
            except Exception as e:
                print(str(e))
                raise
        print('Server stopped.')


class ModelClient:
    '''Client for ModelServer: sends control requests over a REQ socket and,
    after start_recieving() is called, forwards streamed tokens from the SUB
    socket to ``callback_func(text, is_end)``.'''

    def __init__(self, port=5555, port_sub=5556):
        self.port = port
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect("tcp://localhost:{}".format(self.port))
        self.socket_sub = self.context.socket(zmq.SUB)
        self.topic = b'10001'
        self.socket_sub.setsockopt(zmq.SUBSCRIBE, self.topic)
        self.socket_sub.connect("tcp://localhost:{}".format(port_sub))
        self.callback_func = None

    def wrap_result(self, result):
        with BytesIO() as bio:
            torch.save(result, bio)
            return bio.getvalue()

    def unwrap_result(self, result):
        with BytesIO(result) as bio:
            return torch.load(bio, map_location='cuda')

    def recieve_thread(self):
        while True:
            message = self.socket_sub.recv()
            message = message[len(self.topic):]
            message = self.unwrap_result(message)
            if message['type'] == 'generate':
                if self.callback_func is not None:
                    self.callback_func(message['data'], is_end=False)
            elif message['type'] == 'generate_end':
                if self.callback_func is not None:
                    self.callback_func(None, is_end=True)
                break
            else:
                print(message)
                break
        print('receive completed.')

    def start_recieving(self):
        t = threading.Thread(target=self.recieve_thread)
        t.daemon = True
        t.start()

    def generate(self, *args, **kwargs):
        data = {
            'function': 'generate',
            'args': args,
            'kwargs': kwargs
        }
        self.socket.send(self.wrap_result(data))
        result = self.socket.recv()
        return result

    def stop(self):
        data = {
            'function': 'stop_generate'
        }
        self.socket.send(self.wrap_result(data))
        result = self.socket.recv()
        return result

    def test(self):
        data = {
            'function': 'test'
        }
        self.socket.send(self.wrap_result(data))
        result = self.socket.recv()
        return result
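

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original module. The checkpoint and
# tokenizer paths are placeholders, and a client-side LLaMA tokenizer is assumed
# for building the prompt token ids; adjust both for your setup. Run the server
# in one process (pass 'server' as the first argument) and the client in another.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == 'server':
        # Server process: loads the model, then blocks serving REQ/REP on 5555
        # and publishing streamed tokens on PUB 5556.
        server = ModelServer(config_path='./llama-13b-4bit/',            # placeholder
                             model_path='./llama-13b-4bit.safetensors',  # placeholder
                             lora_path=None, groupsize=128)
        server.run()
    else:
        # Client process: tokenize the prompt locally and stream the completion.
        from transformers import LlamaTokenizer

        tokenizer = LlamaTokenizer.from_pretrained('./llama-13b-4bit/')  # placeholder
        done = threading.Event()

        def on_stream(text, is_end):
            # Each update carries the full decoded completion so far; the final
            # callback arrives with text=None and is_end=True.
            if is_end:
                print()
                done.set()
            else:
                print('\r' + text, end='', flush=True)

        client = ModelClient()
        client.callback_func = on_stream
        client.start_recieving()
        input_ids = tokenizer('Hello, my name is', return_tensors='pt').input_ids
        client.generate(inputs=input_ids, max_new_tokens=64, do_sample=True)
        done.wait()  # block until the generate_end message arrives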