From 91241a597bf53bb3568cfca4daa9849658d92fd4 Mon Sep 17 00:00:00 2001 From: Ian Griffin Date: Mon, 10 Nov 2025 13:31:25 +0800 Subject: [PATCH] added splitting options --- ragger.py | 87 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 47 insertions(+), 40 deletions(-) diff --git a/ragger.py b/ragger.py index 701be61..731ef6c 100755 --- a/ragger.py +++ b/ragger.py @@ -12,20 +12,20 @@ import re import readline from argparse import ArgumentParser -from langchain import hub +# from langchain import hub from langchain.chains import create_history_aware_retriever, create_retrieval_chain from langchain.chains.combine_documents import create_stuff_documents_chain from langchain.retrievers.multi_query import MultiQueryRetriever -from langchain_community.chat_message_histories import SQLChatMessageHistory +# from langchain_community.chat_message_histories import SQLChatMessageHistory from langchain_community.document_loaders import TextLoader, WebBaseLoader #, PyPDFLoader from langchain_pymupdf4llm import PyMuPDF4LLMLoader -from langchain_core import vectorstores -from langchain_core.documents import Document +# from langchain_core import vectorstores +# from langchain_core.documents import Document from langchain_core.messages import AIMessage, BaseMessage, HumanMessage -from langchain_core.output_parsers import StrOutputParser +# from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate -from langchain_core.runnables import RunnablePassthrough -from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.runnables import RunnableConfig #, RunnablePassthrough +# from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.vectorstores import InMemoryVectorStore # from langchain_openai import OpenAIEmbeddings, ChatOpenAI from langchain_ollama import OllamaEmbeddings, ChatOllama @@ -35,7 +35,7 @@ from langgraph.graph import START, StateGraph from langgraph.graph.message import add_messages from sys import stderr from termcolor import colored -from typing import Sequence +from typing import NotRequired, Sequence from typing_extensions import Annotated, TypedDict from urllib.parse import urlparse from termcolor import colored @@ -56,6 +56,11 @@ def main(): help="select language model to use", default="gpt-oss" ) + parser.add_argument( + "-s", + help="don't split documents", + action="store_true" + ) args, paths = parser.parse_known_args() # @@ -68,12 +73,23 @@ def main(): # # load documents # - + + splitter_func = lambda docs: docs + if not args.s: + splitter = RecursiveCharacterTextSplitter( + chunk_size=1000, + chunk_overlap=200 + ) + splitter_func = lambda docs: splitter.split_documents(docs) + + + if args.s: pdf_mode = 'single' + else: pdf_mode = 'page' loaders = { - "text": lambda file: TextLoader(file).load(), - "application/pdf": lambda file: PyMuPDF4LLMLoader(file).load(), + "text": lambda file: splitter_func(TextLoader(file).load()), + "application/pdf": lambda file: PyMuPDF4LLMLoader(file, mode=pdf_mode).load(), # "application/pdf": lambda file: PyPDFLoader(file).load(), - "url": lambda file: WebBaseLoader(file).load(), + "url": lambda file: splitter_func(WebBaseLoader(file).load()), } # docs = PyPDFLoader(paths[0]).load() @@ -92,7 +108,7 @@ def main(): # detect filetype else: mimetype, _ = mimetypes.guess_type(path) - if mimetype.startswith("text/"): + if (mimetype or "").startswith("text/"): mimetype = "text" if mimetype not in loaders: @@ -101,19 +117,14 @@ def main(): if args.v: print(">>> Loading %s as %s" % (path, mimetype), file=stderr) docs.extend(loaders[mimetype](path)) - splits = RecursiveCharacterTextSplitter( - chunk_size=1000, - chunk_overlap=200 - ).split_documents(docs) - if args.v: print(">>> Split %d documents into %d chunks" % (len(docs), len(splits)), file=stderr) - - # vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings(openai_api_key=APIKeys.openai)) + # vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings(openai_api_key=APIKeys.openai)) + vectorstore = InMemoryVectorStore( embedding=OllamaEmbeddings(model='nomic-embed-text') ) - vectorstore.add_documents(splits) - if args.v: print(">>> Vectorized %d chunks" % len(splits), file=stderr) + vectorstore.add_documents(docs) + if args.v: print(">>> Vectorized %d chunks" % len(docs), file=stderr) simple_retriever = vectorstore.as_retriever() retriever = MultiQueryRetriever.from_llm( @@ -140,12 +151,11 @@ def main(): ("human", "{input}"), ] ) - - + + history_aware_retriever = create_history_aware_retriever( llm, retriever, contextualize_q_prompt ) - if args.v: print(">>> Created history-aware retriever", file=stderr) # @@ -159,7 +169,7 @@ def main(): "\n\n" "{context}" ) - + qa_prompt = ChatPromptTemplate.from_messages( [ ("system", system_prompt), @@ -203,11 +213,11 @@ def main(): app = workflow.compile(checkpointer=memory) if args.v: print(">>> Created app memory\n", file=stderr) - + # # Chat # - config = {"configurable": {"thread_id": "abc123"}} + config: RunnableConfig = {"configurable": {"thread_id": "abc123"}} while True: try: @@ -223,9 +233,9 @@ def main(): # This state has the same input and output keys as `rag_chain`. class State(TypedDict): input: str - chat_history: Annotated[Sequence[BaseMessage], add_messages] - context: str - answer: str + chat_history: NotRequired[Annotated[Sequence[BaseMessage], add_messages]] + context: NotRequired[str] + answer: NotRequired[str] def parse_markdown(text): lines = text.splitlines() @@ -243,26 +253,23 @@ def parse_markdown(text): # Check for headers if line.startswith("# "): - level = len(line) - len(line.lstrip("#")) - header_text = line.strip() #.lstrip("#").strip() + header_text = line.lstrip("#").strip() formatted_text += colored(header_text, "blue", attrs=["bold", "underline"]) + "\n" continue - + if line.startswith("## "): - level = len(line) - len(line.lstrip("#")) - header_text = line.strip() #.lstrip("#").strip() + header_text = line.lstrip("#").strip() formatted_text += colored(header_text, "blue", attrs=["bold"]) + "\n" continue - + if line.startswith("### "): - level = len(line) - len(line.lstrip("#")) - header_text = line.strip() #.lstrip("#").strip() + header_text = line.lstrip("#").strip() formatted_text += colored(header_text, "cyan", attrs=["bold"]) + "\n" continue # Check for blockquotes if line.startswith(">"): - quote_text = line.strip() #.lstrip(">").strip() + quote_text = line.lstrip(">").strip() formatted_text += colored(quote_text, "yellow") + "\n" continue