from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.chat_models import ChatOllama
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import ConfluenceLoader
import requests
import sys
import os
import json
import pickle
import pathlib
import concurrent.futures
from requests.cookies import cookiejar_from_dict

class WiKi_QA:
    def __init__(self):
        self.wiki_url = "xxxxx"  # Confluence base URL (redacted in the original)
        # Local Ollama embedding model, used for both indexing and querying.
        self.embedding = OllamaEmbeddings(model='smartcreation/dmeta-embedding-zh:f16')

    def gen_vectors(self, jsessionid: str = None, space_key: str = None, page_ids: str = None):
        print('gen_vectors', jsessionid, space_key, page_ids)
        persist_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'persist')
        if not page_ids and space_key == 'healthy' and pathlib.Path(persist_dir).exists():
            # Reuse the persisted Chroma index for the 'healthy' space instead of re-fetching it.
            print('use persist')
            self.vectordb = Chroma(persist_directory=persist_dir, embedding_function=self.embedding)
        else:
            # Authenticate against Confluence with the caller's JSESSIONID cookie.
            s = requests.Session()
            s.cookies = cookiejar_from_dict({'JSESSIONID': jsessionid})
            loader = ConfluenceLoader(
                url=self.wiki_url,
                session=s,
                cloud=False,
                space_key=None if space_key == '' else space_key,
                page_ids=page_ids.split(',') if page_ids else None,
                limit=1,
                max_pages=99999999,
            )
            documents = loader.load()
            print(len(documents))
            # Split pages into overlapping 500-character chunks for embedding.
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=100,
                add_start_index=True,
            )
            texts = text_splitter.split_documents(documents)
            print(len(texts))
            self.vectordb = Chroma.from_documents(documents=texts, embedding=self.embedding)

    def fetch_health_docs(self, jsessionid: str):
        # page_ids.json holds the list of Confluence page ids to download.
        with open(sys.path[0] + '/page_ids.json', 'r', encoding='utf-8') as f:
            page_ids = json.loads(f.read())
        # Download pages concurrently; each worker pickles its documents to disk.
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            future_to_page = {
                executor.submit(self.fetch_docs, jsessionid, page_id): page_id
                for page_id in page_ids
            }
            for future in concurrent.futures.as_completed(future_to_page):
                page_id = future_to_page[future]
                try:
                    future.result()
                except Exception as exc:
                    print('%r generated an exception: %s' % (page_id, exc))

    def fetch_docs(self, jsessionid: str, page_id: str):
        print('start', page_id)
        s = requests.Session()
        s.cookies = cookiejar_from_dict({'JSESSIONID': jsessionid})
        loader = ConfluenceLoader(
            url=self.wiki_url,
            session=s,
            cloud=False,
            page_ids=[page_id],
            limit=1,
            max_pages=99999999,
        )
        # lazy_load streams documents one at a time; pickle each as doc/<page_id>_<index>.pkl.
        for index, one in enumerate(loader.lazy_load()):
            print(one)
            with open(sys.path[0] + f'/doc/{page_id}_{index}.pkl', 'wb') as f:
                pickle.dump(one, f)

    def gen_healthy_vectors(self):
        # Rebuild the persisted index from the documents pickled by fetch_health_docs.
        documents = []
        for pkl_path in pathlib.Path(sys.path[0] + '/doc').iterdir():
            with open(pkl_path, 'rb') as f:
                documents.append(pickle.load(f))
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            add_start_index=True,
        )
        texts = text_splitter.split_documents(documents)
        print(len(texts))
        self.vectordb = Chroma.from_documents(
            documents=texts,
            embedding=self.embedding,
            persist_directory=sys.path[0] + '/persist',
        )

    def retrieval_qa_chain(self):
        # Retrieve the top-8 chunks per question and answer with a local Ollama chat model.
        self.retriever = self.vectordb.as_retriever(search_kwargs={"k": 8})
        self.llm = ChatOllama(model='lgkt/llama3-chinese-alpaca', temperature=0.)
        system_prompt = (
            "You are an assistant for question-answering tasks. "
            "Use the following pieces of retrieved context to answer "
            "the question. If you don't know the answer, say that you "
            "don't know. Use three sentences maximum and keep the "
            "answer concise."
            "\n\n"
            "{context}"
        )
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_prompt),
                ("human", "{input}"),
            ]
        )
        # Stuff the retrieved documents into {context} and wire retriever + LLM together.
        question_answer_chain = create_stuff_documents_chain(self.llm, prompt)
        self.chain = create_retrieval_chain(self.retriever, question_answer_chain)

    def answer_confluence(self, question: str) -> str:
        # create_retrieval_chain returns a dict with 'input', 'context' and 'answer' keys;
        # return only the answer text to match the declared return type.
        result = self.chain.invoke({"input": question})
        return result["answer"]
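

# A minimal usage sketch (not part of the original listing): the JSESSIONID
# value, the space key and the question below are placeholder assumptions,
# and wiki_url must first be set to a real Confluence base URL.
if __name__ == '__main__':
    qa = WiKi_QA()
    qa.gen_vectors(jsessionid='your-jsessionid', space_key='healthy')
    qa.retrieval_qa_chain()
    print(qa.answer_confluence('What does the healthy space document?'))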