150 lines
4.6 KiB
Python
Executable File
150 lines
4.6 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import os
|
|
import sys
|
|
import srv_secrets
|
|
import logging
|
|
|
|
from argparse import ArgumentParser
|
|
from langchain import OpenAI
|
|
from langchain.embeddings import OpenAIEmbeddings
|
|
from langchain.chat_models import ChatOpenAI
|
|
from langchain.vectorstores import Chroma
|
|
from langchain.memory import ConversationBufferMemory
|
|
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
|
|
from langchain.retrievers.multi_query import MultiQueryRetriever
|
|
|
|
from flask import Flask, request, abort
|
|
from linebot.v3 import (
|
|
WebhookHandler
|
|
)
|
|
from linebot.v3.exceptions import (
|
|
InvalidSignatureError
|
|
)
|
|
from linebot.v3.webhooks import (
|
|
MessageEvent,
|
|
TextMessageContent,
|
|
)
|
|
from linebot.v3.messaging import (
|
|
Configuration,
|
|
ApiClient,
|
|
MessagingApi,
|
|
ReplyMessageRequest,
|
|
PushMessageRequest,
|
|
TextMessage
|
|
)
|
|
|
|
app = Flask(__name__)

# Conversation log file; message_text() appends to it for every text message.
log_file_name = "./linebot.message.log"

# Start each run with an empty conversation log. Opening in "w" mode already
# truncates the file, so nothing needs to be written; the context manager
# guarantees the handle is closed.
with open(log_file_name, "w", encoding="utf-8"):
    pass
|
|
|
|
|
|
|
|
# User sessions: one ConversationalRetrievalChain per LINE user id, created
# lazily by message_text(). Nothing in this file ever evicts an entry.
user_sessions = dict()

# Vector store + LLM. The OpenAI key is exported as OPENAI_API_KEY before the
# langchain OpenAI clients below are constructed, so they can pick it up.
os.environ["OPENAI_API_KEY"] = srv_secrets.openai_key
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0,
)
vectorstore = Chroma(
    persist_directory=srv_secrets.chroma_db_dir,
    embedding_function=OpenAIEmbeddings(),
)

# Logging: install the default root handler and let INFO records from the
# "langchain.retrievers.multi_query" logger through.
logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)

# Retrieval. A MultiQueryRetriever variant is kept here, disabled:
# retriever_from_llm = MultiQueryRetriever.from_llm(retriever=vectorstore.as_retriever(), llm=ChatOpenAI(temperature=0))

# Conversational QA: a plain vector-store retriever shared by every user's chain.
# A single shared memory/chain is disabled below; message_text() builds one
# chain per user instead:
# memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# chat = ConversationalRetrievalChain.from_llm(llm,retriever=retriever,memory=memory)
retriever = vectorstore.as_retriever()
|
|
|
|
# LINE credentials come from srv_secrets. (A disabled earlier variant read
# LINE_CHANNEL_SECRET / LINE_CHANNEL_ACCESS_TOKEN from the environment and
# called sys.exit(1) when either was missing.)
channel_secret = srv_secrets.channel_secret
channel_access_token = srv_secrets.channel_access_token

# Webhook dispatcher: handlers are registered on it with @handler.add.
handler = WebhookHandler(channel_secret)

# Messaging API client configuration used when replying to users.
configuration = Configuration(access_token=channel_access_token)
|
|
|
|
@app.route("/", methods=['POST'])
def callback():
    """Receive a LINE webhook POST and dispatch it to the registered handlers.

    Passes the ``X-Line-Signature`` header and the raw request body to
    ``handler.handle``, which checks the signature and invokes the matching
    ``@handler.add`` callbacks (e.g. ``message_text``).

    Returns:
        The plain-text body ``'OK'`` once the webhook has been handled.

    Aborts with HTTP 400 when the signature header is missing or when
    ``handler.handle`` raises ``InvalidSignatureError``.
    """
    # A request without a signature header can never verify; reject it
    # explicitly rather than relying on how the header lookup fails.
    signature = request.headers.get('X-Line-Signature')
    if signature is None:
        abort(400)

    # handler.handle() needs the body exactly as received, as text.
    body = request.get_data(as_text=True)
    app.logger.info("Request body: %s", body)

    try:
        handler.handle(body, signature)
    except InvalidSignatureError:
        abort(400)
    else:
        # Only reached after handler.handle() accepted the signature
        # (previously this was printed before any check had happened).
        print("signature correct")

    return 'OK'
|
|
|
|
|
|
@handler.add(MessageEvent, message=TextMessageContent)
def message_text(event):
    """Reply to a LINE text message using the sender's retrieval chain.

    Each sender gets its own ``ConversationalRetrievalChain`` backed by its
    own ``ConversationBufferMemory``. The chain is created on the sender's
    first message and cached in ``user_sessions``, so later questions see
    that sender's chat history. The incoming message and the bot's answer
    are appended to ``log_file_name``.

    Args:
        event: The ``MessageEvent`` carrying ``TextMessageContent`` that
            ``handler`` dispatches here.
    """
    # NOTE(review): source.user_id may be None for some source types
    # (presumably group/room sources) -- TODO confirm. Such events would all
    # share the session stored under the None key.
    user_id = event.source.user_id
    question = event.message.text

    # The f-string does not raise if user_id is None, unlike "+" concatenation.
    # The newline keeps this record separate from the later bot_answer record.
    # UTF-8 is explicit because LINE messages are routinely non-ASCII.
    with open(log_file_name, "a", encoding="utf-8") as log_file:
        log_file.write(f"user_id: {user_id}, time: {event.timestamp}, message:{question}\n")

    # NOTE(review): nothing in this file evicts sessions or trims their memory,
    # so both grow for the lifetime of the process.
    session = user_sessions.get(user_id)
    if session is None:
        session = ConversationalRetrievalChain.from_llm(
            llm,
            retriever=retriever,
            memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
        )
        user_sessions[user_id] = session

    try:
        answer = str(session({"question": question})["answer"])
    except Exception:
        # Still reply to the user, but record the failure instead of
        # silently swallowing it (the old bare except also caught
        # KeyboardInterrupt/SystemExit).
        app.logger.exception("Failed to answer message from user %s", user_id)
        answer = "message failed to process, please try again"

    with ApiClient(configuration) as api_client:
        line_bot_api = MessagingApi(api_client)
        line_bot_api.reply_message_with_http_info(
            ReplyMessageRequest(
                reply_token=event.reply_token,
                messages=[TextMessage(text=answer)],
            )
        )

    # The file is reopened here so it is not held open during the slow chain
    # call, and no handle leaks if the reply above raises.
    with open(log_file_name, "a", encoding="utf-8") as log_file:
        log_file.write("bot_answer: " + answer + "\n")
|
|
|
|
# with ApiClient(configuration) as api_client:
|
|
# line_bot_api = MessagingApi(api_client)
|
|
# line_bot_api.reply_message_with_http_info(
|
|
# ReplyMessageRequest(
|
|
# reply_token=event.reply_token,
|
|
# messages=[TextMessage(text="Message received.\nProcessing... please wait")]
|
|
# ))
|
|
|
|
|
|
# with ApiClient(configuration) as api_client:
|
|
# line_bot_api = MessagingApi(api_client)
|
|
#
|
|
# line_bot_api.push_message_with_http_info(
|
|
# PushMessageRequest(to=event.source.user_id,messages=[TextMessage(text=answer)]))
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    # Serve the webhook with Flask's built-in server on the configured port.
    listen_port = srv_secrets.srv_port
    app.run(port=listen_port)