app.py
import boto3
import chainlit as cl
from langchain.chains import ConversationalRetrievalChain
from langchain.llms.bedrock import Bedrock
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import BedrockEmbeddings
from langchain_community.vectorstores.faiss import FAISS

from config import aws_access_key, aws_region_name, aws_secret_key
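
# NOTE (assumption): config.py is not included in this snippet. A minimal
# version could read the credentials from environment variables, e.g.:
#
#   import os
#   aws_access_key = os.environ['AWS_ACCESS_KEY_ID']
#   aws_secret_key = os.environ['AWS_SECRET_ACCESS_KEY']
#   aws_region_name = os.environ.get('AWS_REGION', 'us-east-1')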


def create_client():
    """Create a boto3 client for the Bedrock runtime API."""
    bedrock = boto3.client(service_name='bedrock-runtime',
                           region_name=aws_region_name,
                           aws_access_key_id=aws_access_key,
                           aws_secret_access_key=aws_secret_key)
    return bedrock


def create_llm(bedrock_client):
    """Load Llama 2 13B Chat through Bedrock, with token streaming enabled."""
    llm = Bedrock(model_id='meta.llama2-13b-chat-v1',
                  client=bedrock_client,
                  streaming=True,
                  model_kwargs={'temperature': 0, 'top_p': 0.9})
    return llm


def create_prompt():
    """Build the QA prompt for the 'stuff' documents chain."""
    prompt_template = """
Answer the question using only the provided context.
If the question is not relevant to the provided documents, respond with "I don't know" or "This question is outside the bounds of the data I am trained on".

{context}

Question: {question}
Answer:
"""
    prompt = PromptTemplate(template=prompt_template,
                            input_variables=['context', 'question'])
    return prompt


@cl.on_chat_start
async def create_qa_chain():
    # create the Bedrock client and LLM
    bedrock_client = create_client()
    llm = create_llm(bedrock_client=bedrock_client)

    # load embeddings and the prebuilt FAISS index from disk
    bedrock_embeddings = BedrockEmbeddings(model_id='amazon.titan-embed-text-v1',
                                           client=bedrock_client)
    vector_store = FAISS.load_local('faiss_index', bedrock_embeddings,
                                    allow_dangerous_deserialization=True)

    # conversation memory: past turns are read from 'chat_history' and the
    # generated 'answer' is written back after each turn
    message_history = ChatMessageHistory()
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )

    # build the conversational retrieval chain; the custom prompt is passed to
    # the underlying 'stuff' documents chain so create_prompt() is actually used
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        chain_type='stuff',
        retriever=vector_store.as_retriever(search_type='similarity',
                                            search_kwargs={"k": 3}),
        combine_docs_chain_kwargs={'prompt': create_prompt()},
        return_source_documents=True,
        memory=memory,
    )

    # greet the user once the chain is ready
    msg = cl.Message(content="Loading the bot...")
    await msg.send()
    msg.content = "Hi, welcome to the QA Chatbot! Please ask your question."
    await msg.update()

    cl.user_session.set('qa_chain', qa_chain)


@cl.on_message
async def generate_response(query):
    qa_chain = cl.user_session.get('qa_chain')

    # run the chain; the callback streams the final answer back to the UI
    res = await qa_chain.acall(query.content, callbacks=[
        cl.AsyncLangchainCallbackHandler(stream_final_answer=True)
    ])

    # extract the answer and the retrieved source documents
    result, source_documents = res['answer'], res['source_documents']

    # build one "Source: <file>, page: <n>" line per retrieved document,
    # reading the metadata directly instead of regex-parsing its repr
    sources = []
    for doc in source_documents:
        source = str(doc.metadata.get('source', 'unknown'))
        page = doc.metadata.get('page', 'unknown')
        filename = source.replace('\\', '/').split('/')[-1]
        sources.append(f"Source: {filename}, page: {page}")
    metadata_string = "\n".join(sources)

    # append the sources to the answer and send it to the user
    result += f'\n\n{metadata_string}'
    await cl.Message(content=result).send()
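

# ---------------------------------------------------------------------------
# Offline indexing sketch (assumption): the app loads a prebuilt index from
# 'faiss_index', but the build step is not part of this file. A minimal
# version, using the PDF loader and text splitter hinted at by the original
# imports, could look like the function below. The 'PDF Documents' folder
# name comes from the path cleanup in the original handler; the chunk sizes
# are assumptions.
# ---------------------------------------------------------------------------
def build_faiss_index(pdf_dir='PDF Documents', index_path='faiss_index'):
    """Split the PDFs in pdf_dir into chunks, embed them, and save a FAISS index."""
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain_community.document_loaders import PyPDFDirectoryLoader

    documents = PyPDFDirectoryLoader(pdf_dir).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents(documents)

    bedrock_client = create_client()
    embeddings = BedrockEmbeddings(model_id='amazon.titan-embed-text-v1',
                                   client=bedrock_client)
    FAISS.from_documents(chunks, embeddings).save_local(index_path)


# Start the app with: chainlit run app.py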