diff --git a/pyproject.toml b/pyproject.toml index 4df96b5..558fbee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ flake8 = "^6.0.0" [tool.poetry.scripts] chadgpt = "chadgpt.main:main" +learn = "chadgpt.main:learn" [build-system] diff --git a/src/chadgpt/main.py b/src/chadgpt/main.py index ccb4895..01fa08c 100644 --- a/src/chadgpt/main.py +++ b/src/chadgpt/main.py @@ -1,77 +1,14 @@ #!/usr/bin/env python3 -import os -import gradio as gr -from dotenv import load_dotenv -from langchain.chat_models import ChatOpenAI -from gpt_index import ( - SimpleDirectoryReader, - GPTSimpleVectorIndex, - LLMPredictor, - PromptHelper -) +from .config import DB_PATH +from .indexer import construct_index +from .interface import iface -def get_env(): - if not os.environ.get("OPENAI_API_KEY"): - load_dotenv() - - -# parse hidden api key: -def construct_index(directory_path): - # promt params: - max_input_size = 4096 - num_outputs = 512 - max_chunk_overlap = 20 - chunk_size_limit = 600 - - prompt_helper = PromptHelper( - max_input_size, - num_outputs, - max_chunk_overlap, - chunk_size_limit=chunk_size_limit - ) - - llm = ChatOpenAI( - temperature=0.7, - model_name="gpt-3.5-turbo", - max_tokens=num_outputs - ) - - llm_predictor = LLMPredictor(llm) - - # get documents for learn: - documents = SimpleDirectoryReader(directory_path).load_data() - - index = GPTSimpleVectorIndex( - documents, - llm_predictor=llm_predictor, - prompt_helper=prompt_helper - ) - - index_file = os.environ.get('DB_PATH') + "/index.json" - index.save_to_disk(index_file) - - return index - - -def chatbot(input_text): - index_file = os.environ.get("DB_PATH") + "/index.json" - index = GPTSimpleVectorIndex.load_from_disk(index_file) - response = index.query(input_text, response_mode="compact") - return response.response - - -iface = gr.Interface( - fn=chatbot, - inputs=gr.components.Textbox(lines=7, label="Enter your text"), - outputs="text", - title="ISPsystem custom-trained AI Chatbot" -) +def learn(): + construct_index(DB_PATH) def main(): - get_env() - construct_index(os.environ.get("DB_PATH")) iface.launch(share=False)