added splitting options
This commit is contained in:
parent
cf4fdb6fb2
commit
91241a597b
87
ragger.py
87
ragger.py
|
|
@ -12,20 +12,20 @@ import re
|
||||||
import readline
|
import readline
|
||||||
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from langchain import hub
|
# from langchain import hub
|
||||||
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
|
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
|
||||||
from langchain.chains.combine_documents import create_stuff_documents_chain
|
from langchain.chains.combine_documents import create_stuff_documents_chain
|
||||||
from langchain.retrievers.multi_query import MultiQueryRetriever
|
from langchain.retrievers.multi_query import MultiQueryRetriever
|
||||||
from langchain_community.chat_message_histories import SQLChatMessageHistory
|
# from langchain_community.chat_message_histories import SQLChatMessageHistory
|
||||||
from langchain_community.document_loaders import TextLoader, WebBaseLoader #, PyPDFLoader
|
from langchain_community.document_loaders import TextLoader, WebBaseLoader #, PyPDFLoader
|
||||||
from langchain_pymupdf4llm import PyMuPDF4LLMLoader
|
from langchain_pymupdf4llm import PyMuPDF4LLMLoader
|
||||||
from langchain_core import vectorstores
|
# from langchain_core import vectorstores
|
||||||
from langchain_core.documents import Document
|
# from langchain_core.documents import Document
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
from langchain_core.output_parsers import StrOutputParser
|
# from langchain_core.output_parsers import StrOutputParser
|
||||||
from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
|
from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
|
||||||
from langchain_core.runnables import RunnablePassthrough
|
from langchain_core.runnables import RunnableConfig #, RunnablePassthrough
|
||||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
# from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||||
from langchain_core.vectorstores import InMemoryVectorStore
|
from langchain_core.vectorstores import InMemoryVectorStore
|
||||||
# from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
# from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
||||||
from langchain_ollama import OllamaEmbeddings, ChatOllama
|
from langchain_ollama import OllamaEmbeddings, ChatOllama
|
||||||
|
|
@ -35,7 +35,7 @@ from langgraph.graph import START, StateGraph
|
||||||
from langgraph.graph.message import add_messages
|
from langgraph.graph.message import add_messages
|
||||||
from sys import stderr
|
from sys import stderr
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from typing import Sequence
|
from typing import NotRequired, Sequence
|
||||||
from typing_extensions import Annotated, TypedDict
|
from typing_extensions import Annotated, TypedDict
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
@ -56,6 +56,11 @@ def main():
|
||||||
help="select language model to use",
|
help="select language model to use",
|
||||||
default="gpt-oss"
|
default="gpt-oss"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-s",
|
||||||
|
help="don't split documents",
|
||||||
|
action="store_true"
|
||||||
|
)
|
||||||
args, paths = parser.parse_known_args()
|
args, paths = parser.parse_known_args()
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|
@ -68,12 +73,23 @@ def main():
|
||||||
#
|
#
|
||||||
# load documents
|
# load documents
|
||||||
#
|
#
|
||||||
|
|
||||||
|
splitter_func = lambda docs: docs
|
||||||
|
if not args.s:
|
||||||
|
splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=1000,
|
||||||
|
chunk_overlap=200
|
||||||
|
)
|
||||||
|
splitter_func = lambda docs: splitter.split_documents(docs)
|
||||||
|
|
||||||
|
|
||||||
|
if args.s: pdf_mode = 'single'
|
||||||
|
else: pdf_mode = 'page'
|
||||||
loaders = {
|
loaders = {
|
||||||
"text": lambda file: TextLoader(file).load(),
|
"text": lambda file: splitter_func(TextLoader(file).load()),
|
||||||
"application/pdf": lambda file: PyMuPDF4LLMLoader(file).load(),
|
"application/pdf": lambda file: PyMuPDF4LLMLoader(file, mode=pdf_mode).load(),
|
||||||
# "application/pdf": lambda file: PyPDFLoader(file).load(),
|
# "application/pdf": lambda file: PyPDFLoader(file).load(),
|
||||||
"url": lambda file: WebBaseLoader(file).load(),
|
"url": lambda file: splitter_func(WebBaseLoader(file).load()),
|
||||||
}
|
}
|
||||||
|
|
||||||
# docs = PyPDFLoader(paths[0]).load()
|
# docs = PyPDFLoader(paths[0]).load()
|
||||||
|
|
@ -92,7 +108,7 @@ def main():
|
||||||
# detect filetype
|
# detect filetype
|
||||||
else:
|
else:
|
||||||
mimetype, _ = mimetypes.guess_type(path)
|
mimetype, _ = mimetypes.guess_type(path)
|
||||||
if mimetype.startswith("text/"):
|
if (mimetype or "").startswith("text/"):
|
||||||
mimetype = "text"
|
mimetype = "text"
|
||||||
|
|
||||||
if mimetype not in loaders:
|
if mimetype not in loaders:
|
||||||
|
|
@ -101,19 +117,14 @@ def main():
|
||||||
if args.v: print(">>> Loading %s as %s" % (path, mimetype), file=stderr)
|
if args.v: print(">>> Loading %s as %s" % (path, mimetype), file=stderr)
|
||||||
docs.extend(loaders[mimetype](path))
|
docs.extend(loaders[mimetype](path))
|
||||||
|
|
||||||
splits = RecursiveCharacterTextSplitter(
|
|
||||||
chunk_size=1000,
|
|
||||||
chunk_overlap=200
|
|
||||||
).split_documents(docs)
|
|
||||||
if args.v: print(">>> Split %d documents into %d chunks" % (len(docs), len(splits)), file=stderr)
|
|
||||||
|
|
||||||
# vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings(openai_api_key=APIKeys.openai))
|
|
||||||
|
|
||||||
|
# vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings(openai_api_key=APIKeys.openai))
|
||||||
|
|
||||||
vectorstore = InMemoryVectorStore(
|
vectorstore = InMemoryVectorStore(
|
||||||
embedding=OllamaEmbeddings(model='nomic-embed-text')
|
embedding=OllamaEmbeddings(model='nomic-embed-text')
|
||||||
)
|
)
|
||||||
vectorstore.add_documents(splits)
|
vectorstore.add_documents(docs)
|
||||||
if args.v: print(">>> Vectorized %d chunks" % len(splits), file=stderr)
|
if args.v: print(">>> Vectorized %d chunks" % len(docs), file=stderr)
|
||||||
|
|
||||||
simple_retriever = vectorstore.as_retriever()
|
simple_retriever = vectorstore.as_retriever()
|
||||||
retriever = MultiQueryRetriever.from_llm(
|
retriever = MultiQueryRetriever.from_llm(
|
||||||
|
|
@ -140,12 +151,11 @@ def main():
|
||||||
("human", "{input}"),
|
("human", "{input}"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
history_aware_retriever = create_history_aware_retriever(
|
history_aware_retriever = create_history_aware_retriever(
|
||||||
llm, retriever, contextualize_q_prompt
|
llm, retriever, contextualize_q_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.v: print(">>> Created history-aware retriever", file=stderr)
|
if args.v: print(">>> Created history-aware retriever", file=stderr)
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|
@ -159,7 +169,7 @@ def main():
|
||||||
"\n\n"
|
"\n\n"
|
||||||
"{context}"
|
"{context}"
|
||||||
)
|
)
|
||||||
|
|
||||||
qa_prompt = ChatPromptTemplate.from_messages(
|
qa_prompt = ChatPromptTemplate.from_messages(
|
||||||
[
|
[
|
||||||
("system", system_prompt),
|
("system", system_prompt),
|
||||||
|
|
@ -203,11 +213,11 @@ def main():
|
||||||
app = workflow.compile(checkpointer=memory)
|
app = workflow.compile(checkpointer=memory)
|
||||||
|
|
||||||
if args.v: print(">>> Created app memory\n", file=stderr)
|
if args.v: print(">>> Created app memory\n", file=stderr)
|
||||||
|
|
||||||
#
|
#
|
||||||
# Chat
|
# Chat
|
||||||
#
|
#
|
||||||
config = {"configurable": {"thread_id": "abc123"}}
|
config: RunnableConfig = {"configurable": {"thread_id": "abc123"}}
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
@ -223,9 +233,9 @@ def main():
|
||||||
# This state has the same input and output keys as `rag_chain`.
|
# This state has the same input and output keys as `rag_chain`.
|
||||||
class State(TypedDict):
|
class State(TypedDict):
|
||||||
input: str
|
input: str
|
||||||
chat_history: Annotated[Sequence[BaseMessage], add_messages]
|
chat_history: NotRequired[Annotated[Sequence[BaseMessage], add_messages]]
|
||||||
context: str
|
context: NotRequired[str]
|
||||||
answer: str
|
answer: NotRequired[str]
|
||||||
|
|
||||||
def parse_markdown(text):
|
def parse_markdown(text):
|
||||||
lines = text.splitlines()
|
lines = text.splitlines()
|
||||||
|
|
@ -243,26 +253,23 @@ def parse_markdown(text):
|
||||||
|
|
||||||
# Check for headers
|
# Check for headers
|
||||||
if line.startswith("# "):
|
if line.startswith("# "):
|
||||||
level = len(line) - len(line.lstrip("#"))
|
header_text = line.lstrip("#").strip()
|
||||||
header_text = line.strip() #.lstrip("#").strip()
|
|
||||||
formatted_text += colored(header_text, "blue", attrs=["bold", "underline"]) + "\n"
|
formatted_text += colored(header_text, "blue", attrs=["bold", "underline"]) + "\n"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if line.startswith("## "):
|
if line.startswith("## "):
|
||||||
level = len(line) - len(line.lstrip("#"))
|
header_text = line.lstrip("#").strip()
|
||||||
header_text = line.strip() #.lstrip("#").strip()
|
|
||||||
formatted_text += colored(header_text, "blue", attrs=["bold"]) + "\n"
|
formatted_text += colored(header_text, "blue", attrs=["bold"]) + "\n"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if line.startswith("### "):
|
if line.startswith("### "):
|
||||||
level = len(line) - len(line.lstrip("#"))
|
header_text = line.lstrip("#").strip()
|
||||||
header_text = line.strip() #.lstrip("#").strip()
|
|
||||||
formatted_text += colored(header_text, "cyan", attrs=["bold"]) + "\n"
|
formatted_text += colored(header_text, "cyan", attrs=["bold"]) + "\n"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check for blockquotes
|
# Check for blockquotes
|
||||||
if line.startswith(">"):
|
if line.startswith(">"):
|
||||||
quote_text = line.strip() #.lstrip(">").strip()
|
quote_text = line.lstrip(">").strip()
|
||||||
formatted_text += colored(quote_text, "yellow") + "\n"
|
formatted_text += colored(quote_text, "yellow") + "\n"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue