#!/usr/bin/env python3

# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program.  If not, see <https://www.gnu.org/licenses/>.
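
"""Interactive retrieval-augmented chat over local documents and web pages.

Text files, PDFs and URLs given on the command line are embedded with a local
Ollama embedding model and stored in an in-memory vector store; questions typed
at the prompt are answered by a local Ollama chat model using the retrieved
context, with the conversation history kept for the whole session.
"""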
import os
import mimetypes
import readline

from argparse import ArgumentParser
# from langchain import hub
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.retrievers.multi_query import MultiQueryRetriever
# from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_community.document_loaders import TextLoader, WebBaseLoader  # , PyPDFLoader
from langchain_pymupdf4llm import PyMuPDF4LLMLoader
# from langchain_core import vectorstores
# from langchain_core.documents import Document
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
# from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
from langchain_core.runnables import RunnableConfig  # , RunnablePassthrough
# from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.vectorstores import InMemoryVectorStore
# from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from sys import stderr
from term_color_md import render as md_render
from typing import NotRequired, Sequence
from typing_extensions import Annotated, TypedDict
from urllib.parse import urlparse
from termcolor import colored


def main():
    #
    # Readline settings
    #
    readline.parse_and_bind('set editing-mode vi')

    #
    # Parse Arguments
    #
    parser = ArgumentParser()
    parser.add_argument("-v", help="increase output verbosity", action="store_true")
    parser.add_argument(
        "-m",
        type=str,
        help="select language model to use",
        default="gpt-oss"
    )
    parser.add_argument(
        "-s",
        help="don't split documents",
        action="store_true"
    )
    args, paths = parser.parse_known_args()
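    # Everything argparse does not recognize ends up in `paths` and is treated as a
    # document path or URL.  Example invocation (script and file names are placeholders):
    #   ./rag_chat.py -v -m llama3 notes.txt report.pdf https://example.com/article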

    #
    # load LLM
    #
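    # The model named by -m is served by a local Ollama instance and must already be
    # available there (e.g. pulled with `ollama pull gpt-oss`).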
    # llm = ChatOpenAI(model=args.m)
    llm = ChatOllama(model=args.m)
    if args.v: print(">>> Loaded LLM: %s" % llm, file=stderr)

    #
    # load documents
    #
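    # Unless -s is given, plain-text and web documents are split into 1000-character
    # chunks with 200 characters of overlap so that retrieval returns focused passages.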
    splitter_func = lambda docs: docs
    if not args.s:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        splitter_func = lambda docs: splitter.split_documents(docs)

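    # Dispatch table from source type to loader.  PyMuPDF4LLMLoader extracts PDF text
    # as Markdown; 'page' mode yields one document per page, 'single' one document for
    # the whole file, so -s also keeps PDFs unsplit.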
    if args.s: pdf_mode = 'single'
    else: pdf_mode = 'page'
    loaders = {
        "text": lambda file: splitter_func(TextLoader(file).load()),
        "application/pdf": lambda file: PyMuPDF4LLMLoader(file, mode=pdf_mode).load(),
        # "application/pdf": lambda file: PyPDFLoader(file).load(),
        "url": lambda file: splitter_func(WebBaseLoader(file).load()),
    }

    # docs = PyPDFLoader(paths[0]).load()
    docs = []

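    # Classify each argument: http(s) URLs go to the web loader, existing files are
    # dispatched by their guessed MIME type (any text/* type uses the plain-text loader).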
    for path in paths:
        # check if url:
        if urlparse(path).scheme in ("http", "https"):
            if args.v: print(">>> Loading %s as %s" % (path, "url"), file=stderr)
            docs.extend(loaders["url"](path))

        # check if file exists:
        elif not os.path.exists(path):
            raise FileNotFoundError("%s not found" % path)

        # detect filetype
        else:
            mimetype, _ = mimetypes.guess_type(path)
            if (mimetype or "").startswith("text/"):
                mimetype = "text"

            if mimetype not in loaders:
                raise ValueError("Unsupported file type: %s" % mimetype)
            else:
                if args.v: print(">>> Loading %s as %s" % (path, mimetype), file=stderr)
                docs.extend(loaders[mimetype](path))

    # vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings(openai_api_key=APIKeys.openai))
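    # Embed every chunk with the nomic-embed-text model (also served by Ollama) and
    # keep the vectors in process memory; nothing is persisted between runs.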
    vectorstore = InMemoryVectorStore(
        embedding=OllamaEmbeddings(model='nomic-embed-text')
    )
    vectorstore.add_documents(docs)
    if args.v: print(">>> Vectorized %d chunks" % len(docs), file=stderr)

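    # MultiQueryRetriever has the LLM rephrase each question several ways, runs every
    # variant against the vector store, and merges the unique results; this helps when
    # the question's wording differs from the wording in the documents.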
    simple_retriever = vectorstore.as_retriever()
    retriever = MultiQueryRetriever.from_llm(
        retriever=simple_retriever,
        llm=llm
    )
    if args.v: print(">>> Created retriever", file=stderr)

    #
    # History Prompt
    #
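    # The history-aware retriever first asks the LLM to rewrite the incoming question
    # as a standalone query (using the prompt below), then retrieves with that query.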
    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )

    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )

    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )
    if args.v: print(">>> Created history-aware retriever", file=stderr)

    #
    # Prompt
    #
    system_prompt = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Answer in as much detail and as clearly as possible."
        "\n\n"
        "{context}"
    )

    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
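    # "Stuff" chain: all retrieved chunks are inserted into the prompt's {context} slot
    # at once; create_retrieval_chain wires retrieval and answering together and returns
    # a dict with "input", "context" and "answer" keys.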
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(
        history_aware_retriever,
        question_answer_chain
    )
    if args.v: print(">>> Created RAG chain", file=stderr)

    #
    # Memory
    #

    # We then define a simple node that runs the `rag_chain`.
    # The `return` values of the node update the graph state, so here we just
    # update the chat history with the input message and response.
    def call_model(state: State):
        response = rag_chain.invoke(state)
        return {
            "chat_history": [
                HumanMessage(state["input"]),
                AIMessage(response["answer"]),
            ],
            "context": response["context"],
            "answer": response["answer"],
        }

    # Our graph consists only of one node:
    workflow = StateGraph(state_schema=State)
    workflow.add_edge(START, "model")
    workflow.add_node("model", call_model)

    # Finally, we compile the graph with a checkpointer object.
    # This persists the state, in this case in memory.
    memory = MemorySaver()
    app = workflow.compile(checkpointer=memory)

    if args.v: print(">>> Created app memory\n", file=stderr)

    #
    # Chat
    #
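    # One fixed thread_id means MemorySaver accumulates the chat history for the whole
    # session; distinct thread_ids would keep separate, independent conversations.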
config: RunnableConfig = {"configurable": {"thread_id": "abc123"}}
|
|
|
|
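    # Read questions until EOF (Ctrl-D); each answer is rendered as Markdown in the terminal.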
    while True:
        try:
            question = input(colored("Q:", "yellow", attrs=["reverse"]) + " ")
        except EOFError:
            print()
            break

        answer = app.invoke({"input": question}, config=config)["answer"]
        print(colored("A:", "green", attrs=["reverse"]), md_render(answer), end="\n\n")


# We define a dict representing the state of the application.
# This state has the same input and output keys as `rag_chain`.
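# chat_history is annotated with the add_messages reducer, so each turn's messages are
# appended to the stored history instead of replacing it.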
class State(TypedDict):
    input: str
    chat_history: NotRequired[Annotated[Sequence[BaseMessage], add_messages]]
    context: NotRequired[str]
    answer: NotRequired[str]

if __name__ == "__main__":
|
|
main()
|