line_bot/linebot_app.py

158 lines
5.1 KiB
Python
Executable File

#!/usr/bin/env python3
import os
import sys
import srv_secrets
import logging
from argparse import ArgumentParser
from langchain import OpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.vectorstores import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.prompts import PromptTemplate
from flask import Flask, request, abort
from linebot.v3 import (
WebhookHandler
)
from linebot.v3.exceptions import (
InvalidSignatureError
)
from linebot.v3.webhooks import (
MessageEvent,
TextMessageContent,
)
from linebot.v3.messaging import (
Configuration,
ApiClient,
MessagingApi,
ReplyMessageRequest,
PushMessageRequest,
TextMessage
)
app = Flask(__name__)

# Path of the exchange log that message_text appends each user/bot exchange to.
log_file_name = "./linebot.message.log"
# Start every run with an empty log: opening in "w" mode truncates the file,
# and the context manager guarantees the handle is closed again.
with open(log_file_name, "w", encoding="utf-8"):
    pass

# Per-user conversation chains, keyed by LINE user id; filled lazily by message_text.
user_sessions = {}
# Prompt for the document-combining step of each user's ConversationalRetrievalChain:
# {context} receives the retrieved documents and {question} the user's question.
# The "answer_not_found" keyphrase is presumably meant for detecting unanswered
# questions downstream -- nothing in this file checks for it (NOTE(review): confirm).
qa_prompt_template = """
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
always add this exact keyphrase "answer_not_found" to the end of the text if you don't know the answer.
{context}
Question: {question}
Helpful Answer:"""

# ChatOpenAI / OpenAIEmbeddings pick up the key from the OPENAI_API_KEY env var.
os.environ["OPENAI_API_KEY"] = srv_secrets.openai_key
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
# Persisted Chroma store; queries are embedded with OpenAIEmbeddings.
vectorstore = Chroma(persist_directory=srv_secrets.chroma_db_dir, embedding_function=OpenAIEmbeddings())
# from_template needs the template text; calling it with no argument raises TypeError at import.
qa_prompt = PromptTemplate.from_template(qa_prompt_template)
# Raise the MultiQueryRetriever logger to INFO so its generated queries are visible
# (only relevant if that retriever is switched in; the plain retriever below does not use it).
logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)

# Plain retriever over the Chroma store, shared by every per-user conversation chain.
retriever = vectorstore.as_retriever()

# LINE Messaging API credentials are taken from srv_secrets rather than environment variables.
handler = WebhookHandler(srv_secrets.channel_secret)
configuration = Configuration(access_token=srv_secrets.channel_access_token)
@app.route("/", methods=['POST'])
def callback():
    """LINE webhook endpoint: verify the request signature and dispatch its events.

    Returns:
        The string ``'OK'`` once ``handler`` has processed the body.

    Aborts with HTTP 400 when the ``X-Line-Signature`` header is missing or
    ``handler.handle`` raises ``InvalidSignatureError``.
    """
    signature = request.headers.get('X-Line-Signature')
    if signature is None:
        abort(400)

    body = request.get_data(as_text=True)
    app.logger.info("Request body: %s", body)

    try:
        handler.handle(body, signature)
    except InvalidSignatureError:
        abort(400)
    # Only report success after handle() has actually validated the signature.
    print("signature correct")
    return 'OK'
@handler.add(MessageEvent, message=TextMessageContent)
def message_text(event):
    """Answer a LINE text message with the sender's conversational retrieval chain.

    Each user id gets its own ConversationalRetrievalChain (with its own
    ConversationBufferMemory) cached in ``user_sessions``, so follow-up questions
    see that user's chat history. The exchange is appended to ``log_file_name``
    and the answer is sent back using the event's reply token.

    Args:
        event: The ``MessageEvent`` with ``TextMessageContent`` dispatched by ``handler``.
    """
    user_id = event.source.user_id
    question = event.message.text

    # Build the per-user chain on first contact. The custom QA prompt belongs to the
    # chain's document-combining step; passing it to the memory had no effect.
    session = user_sessions.get(user_id)
    if session is None:
        session = ConversationalRetrievalChain.from_llm(
            llm,
            retriever=retriever,
            memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
            combine_docs_chain_kwargs={"prompt": qa_prompt},
        )
        user_sessions[user_id] = session

    try:
        answer = str(session({"question": question})['answer'])
    except Exception:
        # Keep the webhook alive on LLM / vector-store failures, but record why.
        app.logger.exception("failed to answer message from %s", user_id)
        answer = "message failed to process, please try again"

    # Write the whole record in a single call so concurrent requests (Flask's
    # app.run serves requests on threads by default) cannot interleave partial
    # lines. UTF-8 keeps non-ASCII messages independent of the locale encoding.
    with open(log_file_name, "a", encoding="utf-8") as log_file:
        log_file.write(
            "user_id: " + str(user_id) + ", time: " + str(event.timestamp)
            + ", message:" + question + ", bot_answer: " + answer + "\n"
        )

    with ApiClient(configuration) as api_client:
        line_bot_api = MessagingApi(api_client)
        line_bot_api.reply_message_with_http_info(
            ReplyMessageRequest(
                reply_token=event.reply_token,
                messages=[TextMessage(text=answer)],
            )
        )
# with ApiClient(configuration) as api_client:
# line_bot_api = MessagingApi(api_client)
# line_bot_api.reply_message_with_http_info(
# ReplyMessageRequest(
# reply_token=event.reply_token,
# messages=[TextMessage(text="Message received.\nProcessing... please wait")]
# ))
# with ApiClient(configuration) as api_client:
# line_bot_api = MessagingApi(api_client)
#
# line_bot_api.push_message_with_http_info(
# PushMessageRequest(to=event.source.user_id,messages=[TextMessage(text=answer)]))
# Script entry point: serve the webhook with Flask's built-in server
# on the port configured in srv_secrets.
if __name__ == "__main__":
    listen_port = srv_secrets.srv_port
    app.run(port=listen_port)