some updates

parent 51325bde69
commit cf4fdb6fb2

ragger.py (82 lines changed)
@@ -9,6 +9,7 @@
 import os
 import mimetypes
 import re
+import readline

 from argparse import ArgumentParser
 from langchain import hub
@@ -16,7 +17,8 @@ from langchain.chains import create_history_aware_retriever, create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.retrievers.multi_query import MultiQueryRetriever
 from langchain_community.chat_message_histories import SQLChatMessageHistory
-from langchain_community.document_loaders import PyPDFLoader, TextLoader, WebBaseLoader
+from langchain_community.document_loaders import TextLoader, WebBaseLoader #, PyPDFLoader
+from langchain_pymupdf4llm import PyMuPDF4LLMLoader
 from langchain_core import vectorstores
 from langchain_core.documents import Document
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
@@ -25,7 +27,8 @@ from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
 from langchain_core.runnables import RunnablePassthrough
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.vectorstores import InMemoryVectorStore
-from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+# from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+from langchain_ollama import OllamaEmbeddings, ChatOllama
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import START, StateGraph
@@ -38,20 +41,29 @@ from urllib.parse import urlparse
 from termcolor import colored

 def main():
+    #
+    # Readline settings
+    #
+    readline.parse_and_bind('set editing-mode vi')
     #
     # Parse Arguments
     #
     parser = ArgumentParser()
     parser.add_argument("-v", help="increase output verbosity", action="store_true")
-    parser.add_argument("-m", type=str, help="select OpenAI model to use", default="gpt-3.5-turbo")
+    parser.add_argument(
+        "-m",
+        type=str,
+        help="select language model to use",
+        default="gpt-oss"
+    )
     args, paths = parser.parse_known_args()

     #
     # load LLM
     #
-    llm = ChatOpenAI(model=args.m)
-    if args.v:
-        print(">>> Loaded LLM: %s" % llm, file=stderr)
+    # llm = ChatOpenAI(model=args.m)
+    llm = ChatOllama(model=args.m)
+    if args.v: print(">>> Loaded LLM: %s" % llm, file=stderr)

     #
     # load documents
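The LLM now comes from a local Ollama server rather than the OpenAI API. A minimal sketch of the new loading path, assuming Ollama is running on its default port and the model named by -m (default "gpt-oss") has already been pulled:

    from langchain_ollama import ChatOllama

    llm = ChatOllama(model="gpt-oss")             # talks to http://localhost:11434 by default
    reply = llm.invoke("Say hello in one word.")  # returns an AIMessage
    print(reply.content)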
@@ -59,7 +71,8 @@ def main():

     loaders = {
         "text": lambda file: TextLoader(file).load(),
-        "application/pdf": lambda file: PyPDFLoader(file).load(),
+        "application/pdf": lambda file: PyMuPDF4LLMLoader(file).load(),
+        # "application/pdf": lambda file: PyPDFLoader(file).load(),
         "url": lambda file: WebBaseLoader(file).load(),
     }

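PDF parsing moves from PyPDFLoader to PyMuPDF4LLMLoader, whose Documents carry Markdown-flavoured page_content that suits the script's parse_markdown renderer. A sketch of how the dispatch table above gets its key, assuming a hypothetical local file notes.pdf:

    import mimetypes
    from langchain_pymupdf4llm import PyMuPDF4LLMLoader

    mimetype, _ = mimetypes.guess_type("notes.pdf")  # -> "application/pdf"
    docs = PyMuPDF4LLMLoader("notes.pdf").load()     # Documents with Markdown page_content
    print(docs[0].page_content[:200])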
@@ -69,8 +82,7 @@ def main():
     for path in paths:
         # check if url:
         if urlparse(path).scheme in ("http", "https"):
-            if args.v:
-                print(">>> Loading %s as %s" % (path, "url"), file=stderr)
+            if args.v: print(">>> Loading %s as %s" % (path, "url"), file=stderr)
             docs.extend(loaders["url"](path))

         # check if file exists:
@@ -86,25 +98,29 @@ def main():
             if mimetype not in loaders:
                 raise ValueError("Unsupported file type: %s" % mimetype)
             else:
-                if args.v:
-                    print(">>> Loading %s as %s" % (path, mimetype), file=stderr)
+                if args.v: print(">>> Loading %s as %s" % (path, mimetype), file=stderr)
                 docs.extend(loaders[mimetype](path))

-    splits = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200).split_documents(docs)
-    if args.v:
-        print(">>> Split %d documents into %d chunks" % (len(docs), len(splits)), file=stderr)
+    splits = RecursiveCharacterTextSplitter(
+        chunk_size=1000,
+        chunk_overlap=200
+    ).split_documents(docs)
+    if args.v: print(">>> Split %d documents into %d chunks" % (len(docs), len(splits)), file=stderr)

     # vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings(openai_api_key=APIKeys.openai))

-    vectorstore = InMemoryVectorStore(embedding=OpenAIEmbeddings())
+    vectorstore = InMemoryVectorStore(
+        embedding=OllamaEmbeddings(model='nomic-embed-text')
+    )
     vectorstore.add_documents(splits)
-    if args.v:
-        print(">>> Vectorized %d chunks" % len(splits), file=stderr)
+    if args.v: print(">>> Vectorized %d chunks" % len(splits), file=stderr)

     simple_retriever = vectorstore.as_retriever()
-    retriever = MultiQueryRetriever.from_llm(retriever=simple_retriever, llm=llm)
-    if args.v:
-        print(">>> Created retriever", file=stderr)
+    retriever = MultiQueryRetriever.from_llm(
+        retriever=simple_retriever,
+        llm=llm
+    )
+    if args.v: print(">>> Created retriever", file=stderr)

     #
     # History Prompt
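Embedding moves off OpenAI as well: chunks are vectorized with Ollama's nomic-embed-text model into the in-memory store. A self-contained sketch of that path, assuming `ollama pull nomic-embed-text` has been run (the sample document and query are illustrative):

    from langchain_core.documents import Document
    from langchain_core.vectorstores import InMemoryVectorStore
    from langchain_ollama import OllamaEmbeddings

    store = InMemoryVectorStore(embedding=OllamaEmbeddings(model='nomic-embed-text'))
    store.add_documents([Document(page_content="RAG pairs retrieval with generation.")])
    hits = store.as_retriever().invoke("What does RAG pair together?")
    print(hits[0].page_content)

MultiQueryRetriever.from_llm then wraps that retriever, asking the LLM to rephrase each question into several variants and merging the unique hits, which makes retrieval less sensitive to the user's exact wording.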
@@ -130,8 +146,7 @@ def main():
         llm, retriever, contextualize_q_prompt
     )

-    if args.v:
-        print(">>> Created history-aware retriever", file=stderr)
+    if args.v: print(">>> Created history-aware retriever", file=stderr)

     #
     # Prompt
@@ -153,10 +168,11 @@ def main():
         ]
     )
     question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-    if args.v:
-        print(">>> Created RAG chain", file=stderr)
+    rag_chain = create_retrieval_chain(
+        history_aware_retriever,
+        question_answer_chain
+    )
+    if args.v: print(">>> Created RAG chain", file=stderr)

     #
     # Memory
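A hypothetical call against the rag_chain built above (create_retrieval_chain returns a Runnable that reads the "input" key, feeds retrieved context to the QA prompt, and writes its reply under "answer"; the question string is illustrative):

    result = rag_chain.invoke({
        "input": "What is this document about?",
        "chat_history": [],  # HumanMessage/AIMessage objects from earlier turns
    })
    print(result["answer"])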
@@ -186,8 +202,8 @@ def main():
     memory = MemorySaver()
     app = workflow.compile(checkpointer=memory)

-    if args.v:
-        print(">>> Created app memory\n", file=stderr)
+    if args.v: print(">>> Created app memory\n", file=stderr)
+
     #
     # Chat
     #
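MemorySaver checkpoints the graph state in process memory, keyed by a thread id supplied at call time, so one process can hold several independent conversations. A hypothetical invocation of the app compiled above ("chat-1" is illustrative, and the exact state keys depend on the StateGraph schema, which this diff does not show):

    config = {"configurable": {"thread_id": "chat-1"}}  # "chat-1" is a made-up session id
    result = app.invoke({"input": "Hello"}, config=config)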
@@ -195,7 +211,7 @@ def main():

     while True:
         try:
-            question = input(colored("Q: ", "yellow", attrs=["reverse"]))
+            question = input(colored("Q:", "yellow", attrs=["reverse"]) + " ")
         except EOFError:
             print()
             break
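Two input-loop tweaks work together here: the readline binding added at the top of main() gives the prompt vi-style line editing, and moving the trailing space outside colored() keeps the reverse-video highlight on "Q:" alone instead of smearing it across the separator. A minimal sketch:

    import readline
    from termcolor import colored

    readline.parse_and_bind('set editing-mode vi')  # editable input with vi keybindings
    question = input(colored("Q:", "yellow", attrs=["reverse"]) + " ")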
@@ -228,25 +244,25 @@ def parse_markdown(text):
         # Check for headers
         if line.startswith("# "):
             level = len(line) - len(line.lstrip("#"))
-            header_text = line.lstrip("#").strip()
+            header_text = line.strip() #.lstrip("#").strip()
             formatted_text += colored(header_text, "blue", attrs=["bold", "underline"]) + "\n"
             continue

         if line.startswith("## "):
             level = len(line) - len(line.lstrip("#"))
-            header_text = line.lstrip("#").strip()
+            header_text = line.strip() #.lstrip("#").strip()
             formatted_text += colored(header_text, "blue", attrs=["bold"]) + "\n"
             continue

         if line.startswith("### "):
             level = len(line) - len(line.lstrip("#"))
-            header_text = line.lstrip("#").strip()
+            header_text = line.strip() #.lstrip("#").strip()
             formatted_text += colored(header_text, "cyan", attrs=["bold"]) + "\n"
             continue

         # Check for blockquotes
         if line.startswith(">"):
-            quote_text = line.lstrip(">").strip()
+            quote_text = line.strip() #.lstrip(">").strip()
             formatted_text += colored(quote_text, "yellow") + "\n"
             continue

requirements.txt
@@ -3,9 +3,11 @@ gradio
 huggingface_hub
 langchain
 langchain-community
-langchain-openai
+# langchain-openai
+langchain-ollama
+langchain-pymupdf4llm
 langgraph
 openai
-pypdf==5.0.1
+# pypdf==5.0.1
 termcolor
 tiktoken
todo.txt (11 lines changed)
@@ -1,7 +1,12 @@
-async document loading
+save conversation
+editable input
 toggleable rich text
-initial question argument
-no looping argument
+async document loading
+recursive directory reading
+skip files @argument
+proper markdown rendering
+initial question @argument
+no looping @argument
 better code structure
 huggingface models availability
 UI