Refactoring: modular configuration, separated learning and response logic
This commit is contained in:
		
							parent
							
								
									9fa87dbf16
								
							
						
					
					
						commit
						0ecbfbadd0
					
				@ -22,6 +22,7 @@ flake8 = "^6.0.0"
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
[tool.poetry.scripts]
 | 
					[tool.poetry.scripts]
 | 
				
			||||||
chadgpt = "chadgpt.main:main"
 | 
					chadgpt = "chadgpt.main:main"
 | 
				
			||||||
 | 
					learn = "chadgpt.main:learn"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[build-system]
 | 
					[build-system]
 | 
				
			||||||
 | 
				
			|||||||
@ -1,77 +1,14 @@
 | 
				
			|||||||
#!/usr/bin/env python3
 | 
					#!/usr/bin/env python3
 | 
				
			||||||
import os
 | 
					from .config import DB_PATH
 | 
				
			||||||
import gradio as gr
 | 
					from .indexer import construct_index
 | 
				
			||||||
from dotenv import load_dotenv
 | 
					from .interface import iface
 | 
				
			||||||
from langchain.chat_models import ChatOpenAI
 | 
					 | 
				
			||||||
from gpt_index import (
 | 
					 | 
				
			||||||
    SimpleDirectoryReader,
 | 
					 | 
				
			||||||
    GPTSimpleVectorIndex,
 | 
					 | 
				
			||||||
    LLMPredictor,
 | 
					 | 
				
			||||||
    PromptHelper
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_env():
 | 
					def learn():
 | 
				
			||||||
    if not os.environ.get("OPENAI_API_KEY"):
 | 
					    construct_index(DB_PATH)
 | 
				
			||||||
        load_dotenv()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# parse hidden api key:
 | 
					 | 
				
			||||||
def construct_index(directory_path):
 | 
					 | 
				
			||||||
    # promt params:
 | 
					 | 
				
			||||||
    max_input_size = 4096
 | 
					 | 
				
			||||||
    num_outputs = 512
 | 
					 | 
				
			||||||
    max_chunk_overlap = 20
 | 
					 | 
				
			||||||
    chunk_size_limit = 600
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    prompt_helper = PromptHelper(
 | 
					 | 
				
			||||||
        max_input_size,
 | 
					 | 
				
			||||||
        num_outputs,
 | 
					 | 
				
			||||||
        max_chunk_overlap,
 | 
					 | 
				
			||||||
        chunk_size_limit=chunk_size_limit
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    llm = ChatOpenAI(
 | 
					 | 
				
			||||||
        temperature=0.7,
 | 
					 | 
				
			||||||
        model_name="gpt-3.5-turbo",
 | 
					 | 
				
			||||||
        max_tokens=num_outputs
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    llm_predictor = LLMPredictor(llm)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # get documents for learn:
 | 
					 | 
				
			||||||
    documents = SimpleDirectoryReader(directory_path).load_data()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    index = GPTSimpleVectorIndex(
 | 
					 | 
				
			||||||
        documents,
 | 
					 | 
				
			||||||
        llm_predictor=llm_predictor,
 | 
					 | 
				
			||||||
        prompt_helper=prompt_helper
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    index_file = os.environ.get('DB_PATH') + "/index.json"
 | 
					 | 
				
			||||||
    index.save_to_disk(index_file)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return index
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def chatbot(input_text):
 | 
					 | 
				
			||||||
    index_file = os.environ.get("DB_PATH") + "/index.json"
 | 
					 | 
				
			||||||
    index = GPTSimpleVectorIndex.load_from_disk(index_file)
 | 
					 | 
				
			||||||
    response = index.query(input_text, response_mode="compact")
 | 
					 | 
				
			||||||
    return response.response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
iface = gr.Interface(
 | 
					 | 
				
			||||||
    fn=chatbot,
 | 
					 | 
				
			||||||
    inputs=gr.components.Textbox(lines=7, label="Enter your text"),
 | 
					 | 
				
			||||||
    outputs="text",
 | 
					 | 
				
			||||||
    title="ISPsystem custom-trained AI Chatbot"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main():
 | 
					def main():
 | 
				
			||||||
    get_env()
 | 
					 | 
				
			||||||
    construct_index(os.environ.get("DB_PATH"))
 | 
					 | 
				
			||||||
    iface.launch(share=False)
 | 
					    iface.launch(share=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user