WIKI RAG QA

WikiQA 是一个问答系统：它从 Confluence 获取文档并生成文档向量，再用 llama3-chinese-alpaca 模型回答问题。

wiki_qa.py:

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.chat_models import ChatOllama
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate

from langchain_community.vectorstores import Chroma

from langchain_community.document_loaders import ConfluenceLoader
import requests
import sys
import os
import json
import pickle
import pathlib
import concurrent.futures
from requests.cookies import cookiejar_from_dict

class WiKi_QA:
    """Question answering over Confluence pages (LangChain + Ollama + Chroma).

    Usage: build ``self.vectordb`` with ``gen_vectors`` (or with
    ``fetch_health_docs`` followed by ``gen_healthy_vectors``), call
    ``retreival_qa_chain`` to wire a retriever and a chat model into
    ``self.chain``, then ask questions with ``answer_confluence``.
    """

    # Every on-disk artefact is anchored to this module's directory.
    # sys.path[0] (the entry script's directory) depends on how the process
    # was launched, so mixing it with __file__ could make gen_healthy_vectors
    # persist the store somewhere gen_vectors never looks for it.
    BASE_DIR = pathlib.Path(__file__).resolve().parent
    PERSIST_DIR = BASE_DIR / 'persist'          # persisted Chroma store for the 'healthy' space
    DOC_DIR = BASE_DIR / 'doc'                  # pickled Documents written by fetch_docs
    PAGE_IDS_FILE = BASE_DIR / 'page_ids.json'  # JSON list of page ids read by fetch_health_docs

    # Text-splitting parameters shared by gen_vectors and gen_healthy_vectors.
    CHUNK_SIZE = 500
    CHUNK_OVERLAP = 100
    # Upper bound passed to ConfluenceLoader; large enough to mean "no limit".
    MAX_PAGES = 99999999

    def __init__(self, wiki_url: str = "xxxxx",
                 embedding_model: str = 'smartcreation/dmeta-embedding-zh:f16'):
        """Create the helper.

        Args:
            wiki_url: base URL of the (self-hosted, non-cloud) Confluence instance.
            embedding_model: Ollama embedding model used for all vectors.
        """
        self.wiki_url = wiki_url
        self.embedding = OllamaEmbeddings(model=embedding_model)

    @staticmethod
    def _make_session(jsessionid=None):
        """Return a requests session carrying the JSESSIONID cookie when one is given."""
        session = requests.Session()
        if jsessionid:
            session.cookies = cookiejar_from_dict({'JSESSIONID': jsessionid})
        return session

    @staticmethod
    def _parse_page_ids(page_ids):
        """Split a comma-separated id string into a list; None when nothing usable is given."""
        if not page_ids:
            return None
        ids = [pid.strip() for pid in page_ids.split(',') if pid.strip()]
        return ids or None

    @classmethod
    def _split_documents(cls, documents):
        """Split documents into overlapping chunks (start offsets kept in metadata)."""
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=cls.CHUNK_SIZE, chunk_overlap=cls.CHUNK_OVERLAP, add_start_index=True
        )
        texts = splitter.split_documents(documents)
        print(len(texts))
        return texts

    def gen_vectors(self, jsessionid: str = None, space_key: str = None, page_ids: str = None):
        """Build ``self.vectordb`` for the requested Confluence content.

        When no page ids are given, ``space_key`` is ``'healthy'``, and a persisted
        store exists in PERSIST_DIR, that store is reopened and nothing is downloaded.
        Otherwise pages are loaded with ConfluenceLoader, split into chunks and
        embedded into a Chroma collection created without a persist directory.

        Args:
            jsessionid: Confluence session cookie value, or None.
            space_key: Confluence space key; '' or None means "no space filter".
            page_ids: comma-separated page ids, or None/'' for none.
        """
        # Never echo the session cookie itself: it is a live credential.
        print('gen_vectors', '***' if jsessionid else None, space_key, page_ids)
        if not page_ids and space_key == 'healthy' and self.PERSIST_DIR.exists():
            print('use presist')
            self.vectordb = Chroma(persist_directory=str(self.PERSIST_DIR),
                                   embedding_function=self.embedding)
            return

        loader = ConfluenceLoader(
            url=self.wiki_url,
            session=self._make_session(jsessionid),
            cloud=False,
            space_key=space_key or None,
            page_ids=self._parse_page_ids(page_ids),
            limit=1,  # ConfluenceLoader page size per request — TODO confirm why 1 was chosen
            max_pages=self.MAX_PAGES,
        )
        documents = loader.load()
        print(len(documents))
        texts = self._split_documents(documents)
        self.vectordb = Chroma.from_documents(documents=texts, embedding=self.embedding)

    def fetch_health_docs(self, jsessionid: str, max_workers: int = 10):
        """Download every page listed in PAGE_IDS_FILE concurrently via ``fetch_docs``.

        Failures are printed and do not stop the other downloads.

        Args:
            jsessionid: Confluence session cookie value.
            max_workers: size of the download thread pool.

        Returns:
            list of page ids whose download raised an exception.
        """
        with open(self.PAGE_IDS_FILE, 'r', encoding='utf-8') as f:
            page_ids = json.load(f)

        failed = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_page_id = {
                executor.submit(self.fetch_docs, jsessionid, page_id): page_id
                for page_id in page_ids
            }
            for future in concurrent.futures.as_completed(future_to_page_id):
                page_id = future_to_page_id[future]
                try:
                    future.result()
                except Exception as exc:  # best-effort batch: report and continue
                    print('%r generated an exception: %s' % (page_id, exc))
                    failed.append(page_id)
        return failed

    def fetch_docs(self, jsessionid: str, page_id: str):
        """Load one Confluence page and pickle each resulting Document into DOC_DIR.

        Files are named ``{page_id}_{index}.pkl``; the directory is created if missing.
        """
        print('start', page_id)
        loader = ConfluenceLoader(
            url=self.wiki_url,
            session=self._make_session(jsessionid),
            cloud=False,
            page_ids=[page_id],
            limit=1,
            max_pages=self.MAX_PAGES,
        )
        # Called from several threads at once; exist_ok makes this race-free.
        self.DOC_DIR.mkdir(parents=True, exist_ok=True)
        for index, document in enumerate(loader.lazy_load()):
            print(document)
            with open(self.DOC_DIR / f'{page_id}_{index}.pkl', 'wb') as f:
                pickle.dump(document, f)

    def gen_healthy_vectors(self):
        """Embed all pickled Documents in DOC_DIR into a Chroma store persisted at PERSIST_DIR."""
        documents = []
        # SECURITY: pickle.load can execute arbitrary code. DOC_DIR must only ever
        # contain pickles produced by fetch_docs — never untrusted files.
        for pkl_path in sorted(self.DOC_DIR.glob('*.pkl')):
            with open(pkl_path, 'rb') as f:
                documents.append(pickle.load(f))
        texts = self._split_documents(documents)
        self.vectordb = Chroma.from_documents(documents=texts, embedding=self.embedding,
                                              persist_directory=str(self.PERSIST_DIR))

    def retreival_qa_chain(self, k: int = 8, llm_model: str = 'lgkt/llama3-chinese-alpaca'):
        """Wire the vector-store retriever and an Ollama chat model into ``self.chain``.

        Args:
            k: number of chunks the retriever returns per question.
            llm_model: Ollama chat model used to generate answers.

        Raises:
            RuntimeError: if no vector store has been built yet.
        """
        if getattr(self, 'vectordb', None) is None:
            raise RuntimeError('vectordb is not ready; call gen_vectors() or gen_healthy_vectors() first')
        self.retriever = self.vectordb.as_retriever(search_kwargs={"k": k})
        # temperature 0 requests the least random decoding from the model.
        self.llm = ChatOllama(model=llm_model, temperature=0.)

        system_prompt = (
            "You are an assistant for question-answering tasks. "
            "Use the following pieces of retrieved context to answer "
            "the question. If you don't know the answer, say that you "
            "don't know. Use three sentences maximum and keep the "
            "answer concise."
            "\n\n"
            "{context}"
        )

        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_prompt),
                ("human", "{input}"),
            ]
        )

        # The retrieved chunks are "stuffed" into the {context} slot of the system prompt.
        question_answer_chain = create_stuff_documents_chain(self.llm, prompt)
        self.chain = create_retrieval_chain(self.retriever, question_answer_chain)

    # Correctly spelled alias; the misspelled name stays because app.py calls it.
    retrieval_qa_chain = retreival_qa_chain

    def answer_confluence(self, question: str) -> dict:
        """Run ``question`` through the retrieval chain.

        Returns:
            the chain's output dict. For create_retrieval_chain this normally holds
            'input', 'context' (retrieved Documents) and 'answer' — see LangChain docs.

        Raises:
            RuntimeError: if ``retreival_qa_chain`` has not been called.
        """
        chain = getattr(self, 'chain', None)
        if chain is None:
            raise RuntimeError('QA chain is not ready; call retreival_qa_chain() first')
        return chain.invoke({"input": question})

app.py:

import streamlit as st
from wiki_qa import WiKi_QA

# Page-level settings; this must be the first Streamlit call of the script.
st.set_page_config(
    page_title='Wiki Q&A',
    page_icon='📚📚📚📚',
    layout='wide',
    initial_sidebar_state='auto',
)

# Seed per-session state on the first run only; later reruns keep existing values.
for _state_key, _state_default in (("config", {}), ("qa", None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

@st.cache_resource
def load_confluence(config):
    """Build a ready-to-query WiKi_QA for ``config`` and cache it per distinct config.

    ``config`` must contain the keys 'jsession_id', 'space_key' and 'page_ids'.
    """
    wiki_qa = WiKi_QA()
    wiki_qa.gen_vectors(
        jsessionid=config['jsession_id'],
        space_key=config['space_key'],
        page_ids=config['page_ids'],
    )
    wiki_qa.retreival_qa_chain()
    return wiki_qa

# Sidebar form: collect the Confluence session, space and optional page ids.
with st.sidebar.form(key='Form1'):
    st.markdown('## 使用配置')
    jsession_id = st.text_input(label="jsessionid",
                                help="F12获取jsessionid")
    space_key = st.text_input(label="空间",
                              help="wiki的空间",
                              value="healthy")
    page_ids = st.text_input(label="页面id",
                             help="多个页面id用逗号分隔")
    submitted1 = st.form_submit_button(label='Submit')

if submitted1:
    # Trim pasted values (a trailing space would otherwise miss the persisted
    # 'healthy' store), and map empty inputs to None meaning "not provided".
    st.session_state["config"] = {
        "jsession_id": jsession_id.strip() or None,
        "page_ids": page_ids.strip() or None,
        "space_key": space_key.strip(),
    }
    with st.spinner(text="..."):
        try:
            st.session_state["qa"] = load_confluence(st.session_state["config"])
        except Exception as exc:  # UI boundary: show the failure instead of crashing the page
            st.error(f"加载失败: {exc}")
        else:
            st.write("Ingested")


st.title("WIKI Q&A")

question = st.text_input('问一个问题', "商品中心有哪两部分组成?")

if st.button('获取答案', key='button2'):
    qa = st.session_state.get("qa")
    if qa is None:
        st.write("请先设置")
    elif not question.strip():
        st.warning("请输入问题")
    else:
        with st.spinner(text="..."):
            result = qa.answer_confluence(question)
        if isinstance(result, dict):
            # The retrieval chain returns a dict; show the answer and keep the
            # retrieved chunks available (collapsed) for checking sources.
            st.write(result.get("answer", ""))
            context = result.get("context") or []
            if context:
                with st.expander("参考文档"):
                    for doc in context:
                        st.caption(str(getattr(doc, "metadata", {})))
                        st.write(getattr(doc, "page_content", doc))
        else:
            st.write(result)

原版 llama3 的回答基本都是英文，因此改用 lgkt/llama3-chinese-alpaca；但它返回的中文结果感觉比较单薄。

问题:

  1. 为什么要用 RecursiveCharacterTextSplitter？它与其他 splitter 有什么区别？还有哪些 splitter？chunk_size 和 chunk_overlap 应如何设置？
  2. retriever 返回的结果有时明显不准确，如何提高准确率？
  3. embedding 模型应如何选择？对中文效果有影响吗？

参考资料：
- Build a Retrieval Augmented Generation (RAG) App
- Building a Confluence Q&A App with LangChain and ChatGPT
- RAG行业交流中发现的一些问题和改进方法