Refactoring: modular configuration, separated learning and response logic
This commit is contained in:
parent
9fa87dbf16
commit
0ecbfbadd0
@ -22,6 +22,7 @@ flake8 = "^6.0.0"
|
|||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
chadgpt = "chadgpt.main:main"
|
chadgpt = "chadgpt.main:main"
|
||||||
|
learn = "chadgpt.main:learn"
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
@ -1,77 +1,14 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import os
|
from .config import DB_PATH
|
||||||
import gradio as gr
|
from .indexer import construct_index
|
||||||
from dotenv import load_dotenv
|
from .interface import iface
|
||||||
from langchain.chat_models import ChatOpenAI
|
|
||||||
from gpt_index import (
|
|
||||||
SimpleDirectoryReader,
|
|
||||||
GPTSimpleVectorIndex,
|
|
||||||
LLMPredictor,
|
|
||||||
PromptHelper
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_env():
|
def learn():
|
||||||
if not os.environ.get("OPENAI_API_KEY"):
|
construct_index(DB_PATH)
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
# parse hidden api key:
|
|
||||||
def construct_index(directory_path):
|
|
||||||
# promt params:
|
|
||||||
max_input_size = 4096
|
|
||||||
num_outputs = 512
|
|
||||||
max_chunk_overlap = 20
|
|
||||||
chunk_size_limit = 600
|
|
||||||
|
|
||||||
prompt_helper = PromptHelper(
|
|
||||||
max_input_size,
|
|
||||||
num_outputs,
|
|
||||||
max_chunk_overlap,
|
|
||||||
chunk_size_limit=chunk_size_limit
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = ChatOpenAI(
|
|
||||||
temperature=0.7,
|
|
||||||
model_name="gpt-3.5-turbo",
|
|
||||||
max_tokens=num_outputs
|
|
||||||
)
|
|
||||||
|
|
||||||
llm_predictor = LLMPredictor(llm)
|
|
||||||
|
|
||||||
# get documents for learn:
|
|
||||||
documents = SimpleDirectoryReader(directory_path).load_data()
|
|
||||||
|
|
||||||
index = GPTSimpleVectorIndex(
|
|
||||||
documents,
|
|
||||||
llm_predictor=llm_predictor,
|
|
||||||
prompt_helper=prompt_helper
|
|
||||||
)
|
|
||||||
|
|
||||||
index_file = os.environ.get('DB_PATH') + "/index.json"
|
|
||||||
index.save_to_disk(index_file)
|
|
||||||
|
|
||||||
return index
|
|
||||||
|
|
||||||
|
|
||||||
def chatbot(input_text):
|
|
||||||
index_file = os.environ.get("DB_PATH") + "/index.json"
|
|
||||||
index = GPTSimpleVectorIndex.load_from_disk(index_file)
|
|
||||||
response = index.query(input_text, response_mode="compact")
|
|
||||||
return response.response
|
|
||||||
|
|
||||||
|
|
||||||
iface = gr.Interface(
|
|
||||||
fn=chatbot,
|
|
||||||
inputs=gr.components.Textbox(lines=7, label="Enter your text"),
|
|
||||||
outputs="text",
|
|
||||||
title="ISPsystem custom-trained AI Chatbot"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
get_env()
|
|
||||||
construct_index(os.environ.get("DB_PATH"))
|
|
||||||
iface.launch(share=False)
|
iface.launch(share=False)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user