added splitting options

This commit is contained in:
Ian Griffin 2025-11-10 13:31:25 +08:00
parent cf4fdb6fb2
commit 91241a597b
1 changed files with 47 additions and 40 deletions

View File

@ -12,20 +12,20 @@ import re
import readline import readline
from argparse import ArgumentParser from argparse import ArgumentParser
from langchain import hub # from langchain import hub
from langchain.chains import create_history_aware_retriever, create_retrieval_chain from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.retrievers.multi_query import MultiQueryRetriever from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.chat_message_histories import SQLChatMessageHistory # from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_community.document_loaders import TextLoader, WebBaseLoader #, PyPDFLoader from langchain_community.document_loaders import TextLoader, WebBaseLoader #, PyPDFLoader
from langchain_pymupdf4llm import PyMuPDF4LLMLoader from langchain_pymupdf4llm import PyMuPDF4LLMLoader
from langchain_core import vectorstores # from langchain_core import vectorstores
from langchain_core.documents import Document # from langchain_core.documents import Document
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser # from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough from langchain_core.runnables import RunnableConfig #, RunnablePassthrough
from langchain_core.runnables.history import RunnableWithMessageHistory # from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.vectorstores import InMemoryVectorStore from langchain_core.vectorstores import InMemoryVectorStore
# from langchain_openai import OpenAIEmbeddings, ChatOpenAI # from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_ollama import OllamaEmbeddings, ChatOllama from langchain_ollama import OllamaEmbeddings, ChatOllama
@ -35,7 +35,7 @@ from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages from langgraph.graph.message import add_messages
from sys import stderr from sys import stderr
from termcolor import colored from termcolor import colored
from typing import Sequence from typing import NotRequired, Sequence
from typing_extensions import Annotated, TypedDict from typing_extensions import Annotated, TypedDict
from urllib.parse import urlparse from urllib.parse import urlparse
from termcolor import colored from termcolor import colored
@ -56,6 +56,11 @@ def main():
help="select language model to use", help="select language model to use",
default="gpt-oss" default="gpt-oss"
) )
parser.add_argument(
"-s",
help="don't split documents",
action="store_true"
)
args, paths = parser.parse_known_args() args, paths = parser.parse_known_args()
# #
@ -69,11 +74,22 @@ def main():
# load documents # load documents
# #
splitter_func = lambda docs: docs
if not args.s:
splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200
)
splitter_func = lambda docs: splitter.split_documents(docs)
if args.s: pdf_mode = 'single'
else: pdf_mode = 'page'
loaders = { loaders = {
"text": lambda file: TextLoader(file).load(), "text": lambda file: splitter_func(TextLoader(file).load()),
"application/pdf": lambda file: PyMuPDF4LLMLoader(file).load(), "application/pdf": lambda file: PyMuPDF4LLMLoader(file, mode=pdf_mode).load(),
# "application/pdf": lambda file: PyPDFLoader(file).load(), # "application/pdf": lambda file: PyPDFLoader(file).load(),
"url": lambda file: WebBaseLoader(file).load(), "url": lambda file: splitter_func(WebBaseLoader(file).load()),
} }
# docs = PyPDFLoader(paths[0]).load() # docs = PyPDFLoader(paths[0]).load()
@ -92,7 +108,7 @@ def main():
# detect filetype # detect filetype
else: else:
mimetype, _ = mimetypes.guess_type(path) mimetype, _ = mimetypes.guess_type(path)
if mimetype.startswith("text/"): if (mimetype or "").startswith("text/"):
mimetype = "text" mimetype = "text"
if mimetype not in loaders: if mimetype not in loaders:
@ -101,19 +117,14 @@ def main():
if args.v: print(">>> Loading %s as %s" % (path, mimetype), file=stderr) if args.v: print(">>> Loading %s as %s" % (path, mimetype), file=stderr)
docs.extend(loaders[mimetype](path)) docs.extend(loaders[mimetype](path))
splits = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200
).split_documents(docs)
if args.v: print(">>> Split %d documents into %d chunks" % (len(docs), len(splits)), file=stderr)
# vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings(openai_api_key=APIKeys.openai)) # vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings(openai_api_key=APIKeys.openai))
vectorstore = InMemoryVectorStore( vectorstore = InMemoryVectorStore(
embedding=OllamaEmbeddings(model='nomic-embed-text') embedding=OllamaEmbeddings(model='nomic-embed-text')
) )
vectorstore.add_documents(splits) vectorstore.add_documents(docs)
if args.v: print(">>> Vectorized %d chunks" % len(splits), file=stderr) if args.v: print(">>> Vectorized %d chunks" % len(docs), file=stderr)
simple_retriever = vectorstore.as_retriever() simple_retriever = vectorstore.as_retriever()
retriever = MultiQueryRetriever.from_llm( retriever = MultiQueryRetriever.from_llm(
@ -145,7 +156,6 @@ def main():
history_aware_retriever = create_history_aware_retriever( history_aware_retriever = create_history_aware_retriever(
llm, retriever, contextualize_q_prompt llm, retriever, contextualize_q_prompt
) )
if args.v: print(">>> Created history-aware retriever", file=stderr) if args.v: print(">>> Created history-aware retriever", file=stderr)
# #
@ -207,7 +217,7 @@ def main():
# #
# Chat # Chat
# #
config = {"configurable": {"thread_id": "abc123"}} config: RunnableConfig = {"configurable": {"thread_id": "abc123"}}
while True: while True:
try: try:
@ -223,9 +233,9 @@ def main():
# This state has the same input and output keys as `rag_chain`. # This state has the same input and output keys as `rag_chain`.
class State(TypedDict): class State(TypedDict):
input: str input: str
chat_history: Annotated[Sequence[BaseMessage], add_messages] chat_history: NotRequired[Annotated[Sequence[BaseMessage], add_messages]]
context: str context: NotRequired[str]
answer: str answer: NotRequired[str]
def parse_markdown(text): def parse_markdown(text):
lines = text.splitlines() lines = text.splitlines()
@ -243,26 +253,23 @@ def parse_markdown(text):
# Check for headers # Check for headers
if line.startswith("# "): if line.startswith("# "):
level = len(line) - len(line.lstrip("#")) header_text = line.lstrip("#").strip()
header_text = line.strip() #.lstrip("#").strip()
formatted_text += colored(header_text, "blue", attrs=["bold", "underline"]) + "\n" formatted_text += colored(header_text, "blue", attrs=["bold", "underline"]) + "\n"
continue continue
if line.startswith("## "): if line.startswith("## "):
level = len(line) - len(line.lstrip("#")) header_text = line.lstrip("#").strip()
header_text = line.strip() #.lstrip("#").strip()
formatted_text += colored(header_text, "blue", attrs=["bold"]) + "\n" formatted_text += colored(header_text, "blue", attrs=["bold"]) + "\n"
continue continue
if line.startswith("### "): if line.startswith("### "):
level = len(line) - len(line.lstrip("#")) header_text = line.lstrip("#").strip()
header_text = line.strip() #.lstrip("#").strip()
formatted_text += colored(header_text, "cyan", attrs=["bold"]) + "\n" formatted_text += colored(header_text, "cyan", attrs=["bold"]) + "\n"
continue continue
# Check for blockquotes # Check for blockquotes
if line.startswith(">"): if line.startswith(">"):
quote_text = line.strip() #.lstrip(">").strip() quote_text = line.lstrip(">").strip()
formatted_text += colored(quote_text, "yellow") + "\n" formatted_text += colored(quote_text, "yellow") + "\n"
continue continue