Compare commits

...

3 Commits

Author SHA1 Message Date
Ian Griffin afc1e1aad1 prompt fix 2023-08-09 19:28:43 +07:00
Ian Griffin bc1c99c41c added prompt template for qa chain 2023-08-09 19:22:47 +07:00
Ian Griffin 5e588b17f0 finalize logging code 2023-08-05 10:56:12 +07:00
1 changed files with 12 additions and 4 deletions

View File

@ -13,6 +13,7 @@ from langchain.vectorstores import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.prompts import PromptTemplate
from flask import Flask, request, abort
from linebot.v3 import (
@ -44,15 +45,22 @@ log_file = open(log_file_name, "w")
log_file.write("")
log_file.close()
# User Sessions
user_sessions = {}
# init vectorstore embedding
qa_prompt_template = """
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
always add this exact keyphrase "answer_not_found" to the end of the text if you don't know the answer.
{context}
Question: {question}
Helpful Answer:"""
os.environ["OPENAI_API_KEY"] = srv_secrets.openai_key
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
vectorstore = Chroma(persist_directory=srv_secrets.chroma_db_dir, embedding_function=OpenAIEmbeddings())
qa_prompt = PromptTemplate.from_template(qa_prompt_template)
# Setup Logging
logging.basicConfig()
@ -109,7 +117,7 @@ def message_text(event):
# User Session
# create session if none exist
if event.source.user_id not in user_sessions.keys():
user_sessions[event.source.user_id] = ConversationalRetrievalChain.from_llm(llm,retriever=retriever,memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True))
user_sessions[event.source.user_id] = ConversationalRetrievalChain.from_llm(llm,retriever=retriever,memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True, combine_docs_chain_kwargs={"prompt": qa_prompt}))
# unique_docs = retriever_from_llm.get_relevant_documents(query=event.message.text)
with ApiClient(configuration) as api_client:
@ -125,7 +133,7 @@ def message_text(event):
messages=[TextMessage(text=answer)]
))
log_file.write("bot_answer: " + answer + "\n")
log_file.write(", bot_answer: " + answer + "\n")
log_file.close()
# with ApiClient(configuration) as api_client: