diff --git a/age b/age new file mode 100644 index 0000000..e69de29 diff --git a/poetry.lock b/poetry.lock index c994f6a..6936413 100644 --- a/poetry.lock +++ b/poetry.lock @@ -446,7 +446,7 @@ dev = ["flake8", "hypothesis", "ipython", "mypy (>=0.710)", "portray", "pytest ( name = "diskcache" version = "5.6.1" description = "Disk Cache -- Disk and file backed persistent cache." -optional = false +optional = true python-versions = ">=3" files = [ {file = "diskcache-5.6.1-py3-none-any.whl", hash = "sha256:558c6a2d5d7c721bb00e40711803d6804850c9f76c426ed81ecc627fe9d2ce2d"}, @@ -677,7 +677,7 @@ gitdb = ">=4.0.1,<5" name = "gpt4all" version = "1.0.1" description = "Python bindings for GPT4All" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "gpt4all-1.0.1-py3-none-macosx_10_9_universal2.whl", hash = "sha256:e400f99735fe5a1fe6b5fe6acaaa829a1aaa60c16e1ede433d57792235a3fdcf"}, @@ -1031,7 +1031,7 @@ requests = ">=2,<3" name = "llama-cpp-python" version = "0.1.68" description = "A Python wrapper for llama.cpp" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "llama_cpp_python-0.1.68.tar.gz", hash = "sha256:619ca317d771fc0c30ceba68c29c318287cd1cae2eaa14661aec675190295f19"}, @@ -1493,7 +1493,7 @@ files = [ name = "openai" version = "0.27.8" description = "Python client library for the OpenAI API" -optional = false +optional = true python-versions = ">=3.7.1" files = [ {file = "openai-0.27.8-py3-none-any.whl", hash = "sha256:e0a7c2f7da26bdbe5354b03c6d4b82a2f34bd4458c7a17ae1a7092c3e397e03c"}, @@ -2477,7 +2477,7 @@ files = [ name = "tiktoken" version = "0.4.0" description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "tiktoken-0.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:176cad7f053d2cc82ce7e2a7c883ccc6971840a4b5276740d0b732a2b2011f8a"}, @@ -2962,7 +2962,11 @@ files = [ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] +[extras] +all = ["gpt4all", "llama-cpp-python", "openai", "tiktoken"] +local = ["gpt4all", "llama-cpp-python"] + [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "aa717ea6c7c359d5d767c86b792d14b57fe804f7abe03c811612fc02ada794ea" +content-hash = "ab2cc94ae91acd592386537099f6cf9c23a2a856e6a37753605a5b1c1df423c2" diff --git a/requirements.txt b/requirements.txt index d30731a..109e091 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,7 @@ gitdb==4.0.10 GitPython==3.1.41 gpt4all==0.2.3 halo==0.0.31 -huggingface-hub==0.14.1 +huggingface-hub idna==3.7 Jinja2==3.1.3 joblib==1.2.0 @@ -32,7 +32,7 @@ mypy-extensions==1.0.0 networkx==3.1 nltk==3.8.1 numexpr==2.8.4 -numpy==1.24.3 +numpy==1.26.2 openai==0.27.7 openapi-schema-pydantic==1.2.4 packaging==23.1 @@ -56,7 +56,7 @@ tenacity==8.2.2 termcolor==2.3.0 threadpoolctl==3.1.0 tiktoken==0.4.0 -tokenizers==0.13.3 +tokenizers torch==2.0.1 torchvision==0.15.2 tqdm==4.65.0 diff --git a/talk_codebase/cli.py b/talk_codebase/cli.py index 063154c..cdceb65 100644 --- a/talk_codebase/cli.py +++ b/talk_codebase/cli.py @@ -1,10 +1,12 @@ import sys +import subprocess +import requests +import shutil import fire -from talk_codebase.config 
import CONFIGURE_STEPS, save_config, get_config, config_path, remove_api_key, \ - remove_model_type, remove_model_name_local -from talk_codebase.consts import DEFAULT_CONFIG +from talk_codebase.config import CONFIGURE_STEPS, save_config, get_config, config_path, configure, remove_configuration +from talk_codebase.consts import DEFAULT_CONFIG, MODEL_TYPES from talk_codebase.llm import factory_llm from talk_codebase.utils import get_repo @@ -15,6 +17,18 @@ def check_python_version(): sys.exit(1) +def check_ollama_installed(): + return shutil.which("ollama") is not None + + +def check_ollama_running(): + try: + response = requests.get("http://localhost:11434/api/tags") + return response.status_code == 200 + except requests.RequestException: + return False + + def update_config(config): for key, value in DEFAULT_CONFIG.items(): if key not in config: @@ -22,19 +36,8 @@ def update_config(config): return config -def configure(reset=True): - if reset: - remove_api_key() - remove_model_type() - remove_model_name_local() - config = get_config() - config = update_config(config) - for step in CONFIGURE_STEPS: - step(config) - save_config(config) - - def chat_loop(llm): + print("\nšŸ¤– I'm here to help you understand the codebase. Feel free to ask any questions!") while True: query = input("šŸ‘‰ ").lower().strip() if not query: @@ -42,18 +45,55 @@ def chat_loop(llm): continue if query in ('exit', 'quit'): break + print("\nšŸ¤– Analyzing the codebase to provide the best possible answer...") llm.send_query(query) def chat(): - configure(False) config = get_config() + if not config.get("embedding_model_type") or not config.get("chat_model_type"): + print("šŸ¤– Configuration not found. Running configuration process...") + configure(False) + config = get_config() + repo = get_repo() if not repo: print("šŸ¤– Git repository not found") sys.exit(1) - llm = factory_llm(repo.working_dir, config) - chat_loop(llm) + + if config.get("embedding_model_type") != config.get("chat_model_type"): + print("Error: Embedding and chat model types must be the same.") + print("Please run 'talk-codebase configure' to set up your configuration correctly.") + sys.exit(1) + + model_type = config.get("embedding_model_type") + + if model_type in [MODEL_TYPES["OPENAI"], MODEL_TYPES["OPENAI_COMPATIBLE"]]: + if not config.get("openai_compatible_api_key"): + print("Error: API key is missing. Please run 'talk-codebase configure' to set up your API key.") + sys.exit(1) + + if model_type == MODEL_TYPES["OPENAI_COMPATIBLE"] and not config.get("openai_compatible_endpoint"): + print("Error: API endpoint is missing for OpenAI-compatible setup. Please run 'talk-codebase configure' to set up your API endpoint.") + sys.exit(1) + + elif model_type == MODEL_TYPES["OLLAMA"]: + if not check_ollama_installed(): + print("āš ļø Ollama is not found in PATH. Please ensure Ollama is installed and added to your system PATH.") + print("You can download Ollama from: https://ollama.ai/download") + sys.exit(1) + + if not check_ollama_running(): + print("āš ļø Ollama is installed but not running. 
Please start Ollama with 'ollama serve' command.") + sys.exit(1) + + try: + llm = factory_llm(repo.working_dir, config) + chat_loop(llm) + except ValueError as e: + print(f"Error: {str(e)}") + print("Please run 'talk-codebase configure' to set up your configuration correctly.") + sys.exit(1) def main(): diff --git a/talk_codebase/config.py b/talk_codebase/config.py index 1ae132b..bdbd385 100644 --- a/talk_codebase/config.py +++ b/talk_codebase/config.py @@ -1,174 +1,158 @@ import os - -openai_flag = True - -try: - import openai -except: - openai_flag = False - -import gpt4all import questionary import yaml +import requests from talk_codebase.consts import MODEL_TYPES - - - config_path = os.path.join(os.path.expanduser("~"), ".talk_codebase_config.yaml") - def get_config(): if os.path.exists(config_path): with open(config_path, "r") as f: config = yaml.safe_load(f) else: config = {} + + # Set default value for frequency_penalty if not present + if 'frequency_penalty' not in config: + config['frequency_penalty'] = 0.0 + return config - def save_config(config): with open(config_path, "w") as f: yaml.dump(config, f) - -def api_key_is_invalid(api_key): - if not api_key: - return True - try: - openai.api_key = api_key - openai.Engine.list() - except Exception: - return True - return False - - -def get_gpt_models(openai): +def get_ollama_models(purpose): try: - model_lst = openai.Model.list() - except Exception: - print("✘ Failed to retrieve model list") - return [] - - return [i['id'] for i in model_lst['data'] if 'gpt' in i['id']] - - -def configure_model_name_openai(config): - api_key = config.get("api_key") - - if config.get("model_type") != MODEL_TYPES["OPENAI"] or config.get("openai_model_name"): - return - - openai.api_key = api_key - gpt_models = get_gpt_models(openai) - choices = [{"name": model, "value": model} for model in gpt_models] - - if not choices: - print("ℹ No GPT models available") - return + if purpose == "chat": + response = requests.get("http://localhost:11434/api/tags") + if response.status_code == 200: + models = response.json() + return [model['name'] for model in models['models']] + elif purpose == "embedding": + # For now, we'll use a predefined list of Ollama embedding models + return ["nomic-embed-text", "all-MiniLM-L6-v2"] + except requests.RequestException: + print(f"Error: Unable to fetch Ollama {purpose} models. 
Make sure Ollama is running.") + return [] + +def configure_api_key(config, model_type): + if model_type in [MODEL_TYPES["OPENAI"], MODEL_TYPES["OPENAI_COMPATIBLE"]]: + api_key = questionary.password("šŸ¤– Enter your API key:").ask() + config["openai_compatible_api_key"] = api_key + save_config(config) + print("API key saved successfully.") - model_name = questionary.select("šŸ¤– Select model name:", choices).ask() +def configure_api_endpoint(config, purpose): + if config.get(f"{purpose}_model_type") == MODEL_TYPES["OPENAI_COMPATIBLE"]: + endpoint = questionary.text("šŸ¤– Enter the API endpoint:").ask() + config["openai_compatible_endpoint"] = endpoint + save_config(config) + print(f"API endpoint for {purpose} set to: {endpoint}") + elif config.get(f"{purpose}_model_type") == MODEL_TYPES["OLLAMA"]: + if purpose == "chat": + config[f"{purpose}_api_endpoint"] = "http://localhost:11434/api/generate" + elif purpose == "embedding": + config[f"{purpose}_api_endpoint"] = "http://localhost:11434/api/embeddings" + save_config(config) + print(f"Ollama API endpoint for {purpose} set to: {config[f'{purpose}_api_endpoint']}") - if not model_name: - print("✘ No model selected") +def configure_model_name(config, purpose): + model_type = config.get(f"{purpose}_model_type") + + if model_type == MODEL_TYPES["OLLAMA"]: + ollama_models = get_ollama_models(purpose) + if not ollama_models: + print(f"āŒ No Ollama models found for {purpose}. Please make sure Ollama is running and you have pulled some models.") + return + choices = [{"name": model, "value": model} for model in ollama_models] + model_name = questionary.select(f"šŸ¤– Select Ollama model for {purpose}:", choices).ask() + elif model_type in [MODEL_TYPES["OPENAI"], MODEL_TYPES["OPENAI_COMPATIBLE"]]: + prompt = f"šŸ¤– Enter the model name for {purpose} (e.g., text-embedding-ada-002 for embedding, gpt-3.5-turbo for chat):" + model_name = questionary.text(prompt).ask() + else: + print(f"Invalid model type: {model_type}") return - config["openai_model_name"] = model_name - save_config(config) - print("šŸ¤– Model name saved!") - - -def remove_model_name_openai(): - config = get_config() - config["openai_model_name"] = None + config[f"{purpose}_model_name"] = model_name save_config(config) + print(f"šŸ¤– {purpose.capitalize()} model name saved!") - -def configure_model_name_local(config): - if config.get("model_type") != MODEL_TYPES["LOCAL"] or config.get("local_model_name"): - return - - list_models = gpt4all.GPT4All.list_models() - - def get_model_info(model): - return ( - f"{model['name']} " - f"| {model['filename']} " - f"| {model['filesize']} " - f"| {model['parameters']} " - f"| {model['quant']} " - f"| {model['type']}" - ) - +def configure_section(config, purpose): choices = [ - {"name": get_model_info(model), "value": model['filename']} for model in list_models + {"name": "OpenAI", "value": MODEL_TYPES["OPENAI"]}, + {"name": "OpenAI Compatible", "value": MODEL_TYPES["OPENAI_COMPATIBLE"]}, + {"name": "Ollama (Local)", "value": MODEL_TYPES["OLLAMA"]} ] - model_name = questionary.select("šŸ¤– Select model name:", choices).ask() - config["local_model_name"] = model_name - save_config(config) - print("šŸ¤– Model name saved!") - + model_type = questionary.select( + f"šŸ¤– Select model type for {purpose}:", + choices=choices + ).ask() -def remove_model_name_local(): - config = get_config() - config["local_model_name"] = None + config[f"{purpose}_model_type"] = model_type save_config(config) + if model_type in [MODEL_TYPES["OPENAI"], 
MODEL_TYPES["OPENAI_COMPATIBLE"]]: + configure_api_key(config, model_type) + if model_type == MODEL_TYPES["OPENAI_COMPATIBLE"]: + configure_api_endpoint(config, purpose) + elif model_type == MODEL_TYPES["OLLAMA"]: + configure_api_endpoint(config, purpose) -def get_and_validate_api_key(): - prompt = "šŸ¤– Enter your OpenAI API key: " - api_key = input(prompt) - while api_key_is_invalid(api_key): - print("✘ Invalid API key") - api_key = input(prompt) - return api_key + configure_model_name(config, purpose) - -def configure_api_key(config): - if config.get("model_type") != MODEL_TYPES["OPENAI"]: - return - - if api_key_is_invalid(config.get("api_key")): - api_key = get_and_validate_api_key() - config["api_key"] = api_key - save_config(config) - - -def remove_api_key(): +def configure_embedding(): config = get_config() - config["api_key"] = None - save_config(config) - + configure_section(config, "embedding") -def remove_model_type(): +def configure_chat(): config = get_config() - config["model_type"] = None - save_config(config) - - -def configure_model_type(config): - if config.get("model_type"): - return + configure_section(config, "chat") - choices = [{"name": "Local", "value": MODEL_TYPES["LOCAL"]}] - - if openai_flag: choices.append( - {"name": "OpenAI", "value": MODEL_TYPES["OPENAI"]}) - - - model_type = questionary.select( - "šŸ¤– Select model type:", - choices=choices + # Add configuration for frequency_penalty + frequency_penalty = questionary.text( + "šŸ¤– Enter the frequency penalty (default is 0.0, range is -2.0 to 2.0):", + default="0.0" ).ask() - config["model_type"] = model_type + try: + frequency_penalty = float(frequency_penalty) + if -2.0 <= frequency_penalty <= 2.0: + config["frequency_penalty"] = frequency_penalty + save_config(config) + print(f"Frequency penalty set to: {frequency_penalty}") + else: + print("Invalid frequency penalty value. Using default (0.0).") + except ValueError: + print("Invalid input. 
Using default frequency penalty (0.0).") + +def remove_configuration(): + config = get_config() + keys_to_remove = [ + "embedding_model_type", "embedding_model_name", + "chat_model_type", "chat_model_name", + "openai_compatible_api_key", "openai_compatible_endpoint", + "embedding_api_endpoint", "chat_api_endpoint", + "frequency_penalty" + ] + for key in keys_to_remove: + config.pop(key, None) save_config(config) - + print("Configuration removed successfully.") CONFIGURE_STEPS = [ - configure_model_type, - configure_api_key, - configure_model_name_openai, - configure_model_name_local, + configure_embedding, + configure_chat, ] + +def configure(reset=False): + if reset: + remove_configuration() + + config = get_config() + for step in CONFIGURE_STEPS: + step() + + print("Configuration completed successfully.") diff --git a/talk_codebase/consts.py b/talk_codebase/consts.py index 86e170f..acd96f1 100644 --- a/talk_codebase/consts.py +++ b/talk_codebase/consts.py @@ -12,9 +12,9 @@ EXCLUDE_FILES = ['requirements.txt', 'package.json', 'package-lock.json', 'yarn.lock'] MODEL_TYPES = { "OPENAI": "openai", - "LOCAL": "local", + "OPENAI_COMPATIBLE": "openai_compatible", + "OLLAMA": "ollama" } -DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\") DEFAULT_CONFIG = { "max_tokens": "2056", @@ -22,8 +22,6 @@ "chunk_overlap": "256", "k": "2", "temperature": "0.7", - "model_path": DEFAULT_MODEL_DIRECTORY, - "n_batch": "8", } LOADER_MAPPING = { diff --git a/talk_codebase/llm.py b/talk_codebase/llm.py index 66316e9..85184b5 100644 --- a/talk_codebase/llm.py +++ b/talk_codebase/llm.py @@ -1,55 +1,120 @@ import os import time -from typing import Optional +import logging +from typing import Optional, List, Dict, Any -import gpt4all import questionary +import requests from halo import Halo from langchain.vectorstores import FAISS from langchain.callbacks.manager import CallbackManager from langchain.chains import RetrievalQA from langchain.chat_models import ChatOpenAI -from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings -from langchain.llms import LlamaCpp +from langchain.embeddings import OpenAIEmbeddings +from langchain.embeddings.base import Embeddings +from langchain.llms.base import LLM +from langchain.schema import BaseMessage, HumanMessage, AIMessage from langchain.text_splitter import RecursiveCharacterTextSplitter +from pydantic import Field from talk_codebase.consts import MODEL_TYPES from talk_codebase.utils import load_files, get_local_vector_store, calculate_cost, StreamStdOut +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class OllamaEmbeddings(Embeddings): + def __init__(self, model: str, api_url: str): + self.model = model + self.api_url = api_url + logger.info(f"Ollama Embeddings API URL: {self.api_url}") + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + embeddings = [] + for text in texts: + response = requests.post( + self.api_url, + json={"model": self.model, "prompt": text} + ) + embeddings.append(response.json()['embedding']) + return embeddings + + def embed_query(self, text: str) -> List[float]: + response = requests.post( + self.api_url, + json={"model": self.model, "prompt": text} + ) + return response.json()['embedding'] + + +class OllamaChatModel(LLM): + model: str = Field(..., description="The name of the Ollama model to use") + api_url: str = Field(..., description="The API URL for the Ollama service") + + def __init__(self, model: str, 
api_url: str):
+        # The langchain LLM base class is a pydantic model, so the declared fields
+        # must be passed to the initializer; an empty super().__init__() would fail validation.
+        super().__init__(model=model, api_url=api_url)
+        logger.info(f"Ollama Chat API URL: {self.api_url}")
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        response = requests.post(
+            self.api_url,
+            json={"model": self.model, "prompt": prompt, "stream": False}
+        )
+        return response.json()['response']
+
+    @property
+    def _llm_type(self) -> str:
+        return "ollama"
+
 class BaseLLM:
     def __init__(self, root_dir, config):
         self.config = config
-        self.llm = self._create_model()
         self.root_dir = root_dir
+        self.embedding_model = self._create_embedding_model()
+        self.chat_model = self._create_chat_model()
+        logger.info("Creating vector store...")
         self.vector_store = self._create_store(root_dir)
+        logger.info("Vector store created successfully.")

     def _create_store(self, root_dir):
         raise NotImplementedError("Subclasses must implement this method.")

-    def _create_model(self):
+    def _create_embedding_model(self):
+        raise NotImplementedError("Subclasses must implement this method.")
+
+    def _create_chat_model(self):
         raise NotImplementedError("Subclasses must implement this method.")

     def embedding_search(self, query, k):
         return self.vector_store.search(query, k=k, search_type="similarity")

     def _create_vector_store(self, embeddings, index, root_dir):
-        k = int(self.config.get("k"))
+        k = int(self.config.get("k", 2))
         index_path = os.path.join(root_dir, f"vector_store/{index}")
         new_db = get_local_vector_store(embeddings, index_path)
         if new_db is not None:
+            logger.info("Using existing vector store.")
             return new_db.as_retriever(search_kwargs={"k": k})
+        logger.info("Creating new vector store...")
         docs = load_files()
         if len(docs) == 0:
-            print("✘ No documents found")
+            logger.error("No documents found")
             exit(0)
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=int(self.config.get("chunk_size")),
-                                                       chunk_overlap=int(self.config.get("chunk_overlap")))
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=int(self.config.get("chunk_size", 2056)),
+                                                       chunk_overlap=int(self.config.get("chunk_overlap", 256)),
+                                                       separators=["\n\n", "\n", " ", ""])
         texts = text_splitter.split_documents(docs)
-        if index == MODEL_TYPES["OPENAI"]:
-            cost = calculate_cost(docs, self.config.get("openai_model_name"))
+
+        model_type = self.config.get("embedding_model_type")
+        cost = calculate_cost(texts, self.config.get("embedding_model_name"), model_type)
+
+        if cost > 0:
             approve = questionary.select(
                 f"Creating a vector store will cost ~${cost:.5f}. Do you want to continue?",
                 choices=[
@@ -69,57 +134,125 @@ def _create_vector_store(self, embeddings, index, root_dir):
         time.sleep(1.5)
         spinners.succeed(f"Created vector store")
+        logger.info("New vector store created successfully.")
         return db.as_retriever(search_kwargs={"k": k})

     def send_query(self, query):
-        retriever = self._create_store(self.root_dir)
+        logger.info(f"Processing query: {query}")
         qa = RetrievalQA.from_chain_type(
-            llm=self.llm,
+            llm=self.chat_model,
             chain_type="stuff",
-            retriever=retriever,
+            retriever=self.vector_store,
             return_source_documents=True
         )
-        docs = qa(query)
+        # Add a custom prompt to encourage more relevant responses
+        custom_prompt = f"""
+        You are an AI assistant specialized in analyzing codebases.
+        Given the following query about the codebase, provide the most relevant and helpful response possible.
+        If you're not entirely sure, make an educated guess based on the context of the codebase.
+ Query: {query} + """ + docs = qa({"query": custom_prompt}) + + # Print the response + print("\nšŸ¤– Response:") + print(docs['result']) + + # Print the source files file_paths = [os.path.abspath(s.metadata["source"]) for s in docs['source_documents']] - print('\n'.join([f'šŸ“„ {file_path}:' for file_path in file_paths])) + print('\nšŸ“ Source files:') + print('\n'.join([f'- {file_path}' for file_path in file_paths])) + logger.info("Query processed successfully.") -class LocalLLM(BaseLLM): +class OllamaLLM(BaseLLM): def _create_store(self, root_dir: str) -> Optional[FAISS]: - embeddings = HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2') - return self._create_vector_store(embeddings, MODEL_TYPES["LOCAL"], root_dir) - - def _create_model(self): - os.makedirs(self.config.get("model_path"), exist_ok=True) - gpt4all.GPT4All.retrieve_model(model_name=self.config.get("local_model_name"), - model_path=self.config.get("model_path")) - model_path = os.path.join(self.config.get("model_path"), self.config.get("local_model_name")) - model_n_ctx = int(self.config.get("max_tokens")) - model_n_batch = int(self.config.get("n_batch")) - callbacks = CallbackManager([StreamStdOut()]) - llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks, - verbose=False) - llm.client.verbose = False - return llm + return self._create_vector_store(self.embedding_model, MODEL_TYPES["OLLAMA"], root_dir) + + def _create_embedding_model(self): + embedding_model = self.config.get("embedding_model_name") + embedding_api_url = self.config.get("embedding_api_endpoint") + logger.info(f"Creating Ollama Embedding model: {embedding_model} with API URL: {embedding_api_url}") + return OllamaEmbeddings( + model=embedding_model, + api_url=embedding_api_url + ) + + def _create_chat_model(self): + chat_model = self.config.get("chat_model_name") + chat_api_url = self.config.get("chat_api_endpoint") + logger.info(f"Creating Ollama Chat model: {chat_model} with API URL: {chat_api_url}") + return OllamaChatModel( + model=chat_model, + api_url=chat_api_url + ) class OpenAILLM(BaseLLM): def _create_store(self, root_dir: str) -> Optional[FAISS]: - embeddings = OpenAIEmbeddings(openai_api_key=self.config.get("api_key")) - return self._create_vector_store(embeddings, MODEL_TYPES["OPENAI"], root_dir) + return self._create_vector_store(self.embedding_model, MODEL_TYPES["OPENAI"], root_dir) + + def _create_embedding_model(self): + logger.info("Creating OpenAI Embedding model") + return OpenAIEmbeddings( + model=self.config.get("embedding_model_name"), + openai_api_key=self.config.get("openai_compatible_api_key") + ) + + def _create_chat_model(self): + logger.info("Creating OpenAI Chat model") + return ChatOpenAI( + model_name=self.config.get("chat_model_name"), + openai_api_key=self.config.get("openai_compatible_api_key"), + streaming=True, + max_tokens=int(self.config.get("max_tokens", 2056)), + callback_manager=CallbackManager([StreamStdOut()]), + temperature=float(self.config.get("temperature", 0.7)), + presence_penalty=0.6, # Encourage the model to talk about new topics + frequency_penalty=float(self.config.get("frequency_penalty", 0.0)) # Use configured value or default to 0.0 + ) + + +class OpenAICompatibleLLM(BaseLLM): + def _create_store(self, root_dir: str) -> Optional[FAISS]: + return self._create_vector_store(self.embedding_model, MODEL_TYPES["OPENAI_COMPATIBLE"], root_dir) - def _create_model(self): - return ChatOpenAI(model_name=self.config.get("openai_model_name"), - 
openai_api_key=self.config.get("api_key"), - streaming=True, - max_tokens=int(self.config.get("max_tokens")), - callback_manager=CallbackManager([StreamStdOut()]), - temperature=float(self.config.get("temperature"))) + def _create_embedding_model(self): + logger.info("Creating OpenAI Compatible Embedding model") + return OpenAIEmbeddings( + model=self.config.get("embedding_model_name"), + openai_api_key=self.config.get("openai_compatible_api_key"), + openai_api_base=self.config.get("openai_compatible_endpoint") + ) + + def _create_chat_model(self): + logger.info("Creating OpenAI Compatible Chat model") + return ChatOpenAI( + model_name=self.config.get("chat_model_name"), + openai_api_key=self.config.get("openai_compatible_api_key"), + openai_api_base=self.config.get("openai_compatible_endpoint"), + streaming=True, + max_tokens=int(self.config.get("max_tokens", 2056)), + callback_manager=CallbackManager([StreamStdOut()]), + temperature=float(self.config.get("temperature", 0.7)), + frequency_penalty=float(self.config.get("frequency_penalty", 0.0)) # Use configured value or default to 0.0 + ) def factory_llm(root_dir, config): - if config.get("model_type") == "openai": + embedding_type = config.get("embedding_model_type") + chat_type = config.get("chat_model_type") + + if embedding_type != chat_type: + raise ValueError("Embedding and chat model types must be the same.") + + logger.info(f"Creating LLM of type: {embedding_type}") + if embedding_type == MODEL_TYPES["OPENAI"]: return OpenAILLM(root_dir, config) + elif embedding_type == MODEL_TYPES["OPENAI_COMPATIBLE"]: + return OpenAICompatibleLLM(root_dir, config) + elif embedding_type == MODEL_TYPES["OLLAMA"]: + return OllamaLLM(root_dir, config) else: - return LocalLLM(root_dir, config) + raise ValueError(f"Invalid model type: {embedding_type}") diff --git a/talk_codebase/utils.py b/talk_codebase/utils.py index 29a076b..7f2e334 100644 --- a/talk_codebase/utils.py +++ b/talk_codebase/utils.py @@ -5,7 +5,7 @@ from langchain.vectorstores import FAISS from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler -from talk_codebase.consts import LOADER_MAPPING, EXCLUDE_FILES +from talk_codebase.consts import LOADER_MAPPING, EXCLUDE_FILES, MODEL_TYPES def get_repo(): @@ -48,13 +48,20 @@ def load_files(): return files -def calculate_cost(texts, model_name): - enc = tiktoken.encoding_for_model(model_name) - all_text = ''.join([text.page_content for text in texts]) - tokens = enc.encode(all_text) - token_count = len(tokens) - cost = (token_count / 1000) * 0.0004 - return cost +def calculate_cost(texts, model_name, model_type): + if model_type == MODEL_TYPES["OLLAMA"]: + return 0 # No cost for Ollama models + + try: + enc = tiktoken.encoding_for_model(model_name) + all_text = ''.join([text.page_content for text in texts]) + tokens = enc.encode(all_text) + token_count = len(tokens) + cost = (token_count / 1000) * 0.0004 + return cost + except KeyError: + print(f"Warning: Unable to calculate cost for model {model_name}. Assuming no cost.") + return 0 def get_local_vector_store(embeddings, path):
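
For reference, a minimal standalone sketch of the three Ollama HTTP endpoints the new code paths rely on (GET /api/tags in check_ollama_running and get_ollama_models, POST /api/embeddings in OllamaEmbeddings, POST /api/generate in OllamaChatModel._call). It assumes Ollama is serving on localhost:11434; the model names are placeholders for whatever has been pulled locally, and the request/response shapes simply mirror the calls made in this patch.

# smoke_test_ollama.py -- assumes an Ollama server on localhost:11434 and that the
# placeholder models "nomic-embed-text" and "llama3" have been pulled with `ollama pull`.
import requests

BASE = "http://localhost:11434"

# 1. List locally available models (same call as check_ollama_running / get_ollama_models).
tags = requests.get(f"{BASE}/api/tags", timeout=5)
tags.raise_for_status()
print("models:", [m["name"] for m in tags.json()["models"]])

# 2. Embed a code snippet (same payload shape as OllamaEmbeddings.embed_query).
emb = requests.post(
    f"{BASE}/api/embeddings",
    json={"model": "nomic-embed-text", "prompt": "def add(a, b): return a + b"},
    timeout=30,
)
print("embedding dimensions:", len(emb.json()["embedding"]))

# 3. Non-streaming generation (same payload shape as OllamaChatModel._call).
gen = requests.post(
    f"{BASE}/api/generate",
    json={"model": "llama3", "prompt": "Summarize what this repository does.", "stream": False},
    timeout=120,
)
print(gen.json()["response"])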