first commit
This commit is contained in:
25
agent_dhal/agentdhal_extensions/tools/graphrag/__init__.py
Normal file
25
agent_dhal/agentdhal_extensions/tools/graphrag/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from ._config import (
|
||||
GlobalContextConfig,
|
||||
GlobalDataConfig,
|
||||
LocalContextConfig,
|
||||
LocalDataConfig,
|
||||
MapReduceConfig,
|
||||
SearchConfig,
|
||||
)
|
||||
from ._global_search import GlobalSearchTool, GlobalSearchToolArgs, GlobalSearchToolReturn
|
||||
from ._local_search import LocalSearchTool, LocalSearchToolArgs, LocalSearchToolReturn
|
||||
|
||||
# Public surface of the graphrag tools package: the two search tools, their
# argument/return schemas, and the configuration models they accept.
# The declaration order is kept exactly as originally written.
__all__ = [
    "GlobalSearchTool",
    "LocalSearchTool",
    "GlobalDataConfig",
    "LocalDataConfig",
    "GlobalContextConfig",
    "GlobalSearchToolArgs",
    "GlobalSearchToolReturn",
    "LocalContextConfig",
    "LocalSearchToolArgs",
    "LocalSearchToolReturn",
    "MapReduceConfig",
    "SearchConfig",
]
|
||||
59
agent_dhal/agentdhal_extensions/tools/graphrag/_config.py
Normal file
59
agent_dhal/agentdhal_extensions/tools/graphrag/_config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DataConfig(BaseModel):
    """Locations of the GraphRAG index tables shared by both search tools.

    The search tools build each parquet path as
    ``f"{input_dir}/{<table>}.parquet"``, so every ``*_table`` field is a file
    base name without the ``.parquet`` suffix.
    """

    # Directory (or path prefix) that holds the indexed parquet files.
    input_dir: str

    # Base name of the entity parquet file.
    entity_table: str = "entities"

    # NOTE(review): not referenced by the search tools shown alongside this
    # module; presumably reserved for a separate embedding table -- confirm.
    entity_embedding_table: str = "entities"

    # Base name of the community parquet file.
    community_table: str = "communities"

    # Community level handed to the GraphRAG ``read_indexer_*`` adapters.
    community_level: int = 2
|
||||
|
||||
|
||||
class GlobalDataConfig(DataConfig):
    """Data locations for global search: the shared tables plus community reports."""

    # Base name of the community-report parquet file (no ``.parquet`` suffix).
    community_report_table: str = "community_reports"
|
||||
|
||||
|
||||
class LocalDataConfig(DataConfig):
    """Data locations for local search: the shared tables plus relationships and text units."""

    # Base name of the relationship parquet file (no ``.parquet`` suffix).
    relationship_table: str = "relationships"

    # Base name of the text-unit parquet file (no ``.parquet`` suffix).
    text_unit_table: str = "text_units"
|
||||
|
||||
|
||||
class ContextConfig(BaseModel):
    """Base settings shared by the global and local context configurations."""

    # Token budget; the search tools forward it to the context builder
    # parameters under the ``max_tokens`` key.
    max_data_tokens: int = 8000
|
||||
|
||||
|
||||
class GlobalContextConfig(ContextConfig):
    """Context-building options for the global search tool.

    ``GlobalSearchTool`` copies every field into the ``context_builder_params``
    dictionary it gives to GraphRAG's ``GlobalSearch`` (each under the key of
    the same name, except ``max_data_tokens`` which is sent as ``max_tokens``
    and additionally as the ``max_data_tokens`` argument).
    """

    # Forwarded as ``use_community_summary``.
    use_community_summary: bool = False

    # Forwarded as ``shuffle_data``.
    shuffle_data: bool = True

    # Community rank options, forwarded under the same names.
    include_community_rank: bool = True
    min_community_rank: int = 0
    community_rank_name: str = "rank"

    # Community weight options, forwarded under the same names.
    include_community_weight: bool = True
    community_weight_name: str = "occurrence weight"
    normalize_community_weight: bool = True

    # Overrides the base-class default of 8000 with a larger budget.
    max_data_tokens: int = 12000
|
||||
|
||||
|
||||
class LocalContextConfig(ContextConfig):
    """Context-building options for the local search tool.

    ``LocalSearchTool`` copies every field (plus the inherited
    ``max_data_tokens``, sent as ``max_tokens``) into the
    ``context_builder_params`` dictionary it gives to GraphRAG's
    ``LocalSearch``.
    """

    # Forwarded as ``text_unit_prop``; presumably the share of the context
    # budget given to text units -- confirm against GraphRAG.
    text_unit_prop: float = 0.5

    # Forwarded as ``community_prop``; presumably the share of the context
    # budget given to community content -- confirm against GraphRAG.
    community_prop: float = 0.25

    # Entity rank options, forwarded under the same names.
    include_entity_rank: bool = True
    rank_description: str = "number of relationships"

    # Relationship options, forwarded under the same names.
    include_relationship_weight: bool = True
    relationship_ranking_attribute: str = "rank"
|
||||
|
||||
|
||||
class MapReduceConfig(BaseModel):
    """LLM settings for the map and reduce phases of the global search tool."""

    # Become ``max_tokens`` / ``temperature`` of the map-phase LLM parameters.
    map_max_tokens: int = 1000
    map_temperature: float = 0.0

    # Become ``max_tokens`` / ``temperature`` of the reduce-phase LLM parameters.
    reduce_max_tokens: int = 2000
    reduce_temperature: float = 0.0

    # Passed straight through to ``GlobalSearch`` under the same names.
    allow_general_knowledge: bool = False
    json_mode: bool = False
    response_type: str = "multiple paragraphs"
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
    """LLM settings for the local search tool."""

    # Become the ``max_tokens`` / ``temperature`` entries of the
    # ``model_params`` dictionary given to ``LocalSearch``.
    max_tokens: int = 1500
    temperature: float = 0.0

    # Passed to ``LocalSearch`` as ``response_type``.
    response_type: str = "multiple paragraphs"
|
||||
233
agent_dhal/agentdhal_extensions/tools/graphrag/_global_search.py
Normal file
233
agent_dhal/agentdhal_extensions/tools/graphrag/_global_search.py
Normal file
@@ -0,0 +1,233 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
from graphrag.language_model.protocol import ChatModel
|
||||
from graphrag.query.indexer_adapters import (
|
||||
read_indexer_communities,
|
||||
read_indexer_entities,
|
||||
read_indexer_reports,
|
||||
)
|
||||
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
|
||||
from graphrag.query.structured_search.global_search.search import GlobalSearch
|
||||
|
||||
from ._config import GlobalContextConfig as ContextConfig
|
||||
from ._config import GlobalDataConfig as DataConfig
|
||||
from ._config import MapReduceConfig
|
||||
|
||||
# Module-level default configurations. They serve as the default arguments of
# ``GlobalSearchTool.__init__`` and are what ``from_settings`` passes along.
# The instances are shared, so they should be treated as read-only.
_default_mapreduce_config = MapReduceConfig()
_default_context_config = ContextConfig()
|
||||
|
||||
|
||||
class GlobalSearchToolArgs(BaseModel):
    """Argument schema of the global search tool: one free-text query."""

    query: str = Field(
        ...,
        description="The user query to perform global search on.",
    )
|
||||
|
||||
|
||||
class GlobalSearchToolReturn(BaseModel):
    """Return schema of the global search tool: the generated answer text."""

    # Described the same way as ``LocalSearchToolReturn.answer`` so both tools
    # expose a consistent, self-describing schema.
    answer: str = Field(..., description="The answer to the user query.")
|
||||
|
||||
|
||||
class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]):
    """Enables running GraphRAG global search queries as an agent tool.

    This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
    The search combines graph-based document relationships with semantic embeddings to find relevant information.

    .. note::
        This tool requires the :code:`graphrag` extra for the :code:`agentdhal-ext` package.

        To install:

        .. code-block:: bash

            pip install -U "agentdhal-agentchat" "agentdhal-ext[graphrag]"

        Before using this tool, you must complete the GraphRAG setup and indexing process:

        1. Follow the GraphRAG documentation to initialize your project and settings
        2. Configure and tune your prompts for the specific use case
        3. Run the indexing process to generate the required data files
        4. Ensure you have the settings.yaml file from the setup process

        Please refer to the `GraphRAG documentation <https://microsoft.github.io/graphrag/>`_
        for detailed instructions on completing these prerequisite steps.

    Example usage with AssistantAgent:

    .. code-block:: python

        import asyncio
        from pathlib import Path
        from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
        from agentdhal_agentchat.ui import Console
        from agentdhal_extensions.tools.graphrag import GlobalSearchTool
        from agentdhal_agentchat.agents import AssistantAgent


        async def main():
            # Initialize the OpenAI client
            openai_client = OpenAIChatCompletionClient(
                model="gpt-4o-mini",
                api_key="<api-key>",
            )

            # Set up global search tool
            global_tool = GlobalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))

            # Create assistant agent with the global search tool
            assistant_agent = AssistantAgent(
                name="search_assistant",
                tools=[global_tool],
                model_client=openai_client,
                system_message=(
                    "You are a tool selector AI assistant using the GraphRAG framework. "
                    "Your primary task is to determine the appropriate search tool to call based on the user's query. "
                    "For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search_tool' function."
                ),
            )

            # Run a sample query
            query = "What is the overall sentiment of the community reports?"
            await Console(assistant_agent.run_stream(task=query))


        if __name__ == "__main__":
            asyncio.run(main())

    Args:
        token_encoder (tiktoken.Encoding): The tokenizer handed to both the context builder and the search engine
        model: The chat model to use for search (GraphRAG ChatModel)
        data_config (DataConfig): Configuration for data source locations and settings
        context_config (ContextConfig, optional): Configuration for context building. Defaults to default config.
        mapreduce_config (MapReduceConfig, optional): Configuration for the map and reduce LLM calls.
            Defaults to default config.
    """

    def __init__(
        self,
        token_encoder: tiktoken.Encoding,
        model: ChatModel,
        data_config: DataConfig,
        context_config: ContextConfig = _default_context_config,
        mapreduce_config: MapReduceConfig = _default_mapreduce_config,
    ):
        super().__init__(
            args_type=GlobalSearchToolArgs,
            return_type=GlobalSearchToolReturn,
            name="global_search_tool",
            description="Perform a global search with given parameters using graphrag.",
        )
        self._model = model

        # Load the indexed tables (same order as before: communities,
        # entities, community reports).
        input_dir = data_config.input_dir
        community_df = self._load_parquet_table(input_dir, data_config.community_table)
        entity_df = self._load_parquet_table(input_dir, data_config.entity_table)
        report_df = self._load_parquet_table(input_dir, data_config.community_report_table)

        # Convert the raw frames into GraphRAG query objects.
        communities = read_indexer_communities(community_df, report_df)
        reports = read_indexer_reports(report_df, community_df, data_config.community_level)
        entities = read_indexer_entities(entity_df, community_df, data_config.community_level)

        context_builder = GlobalCommunityContext(
            community_reports=reports,
            communities=communities,
            entities=entities,
            token_encoder=token_encoder,
        )

        context_builder_params = {
            "use_community_summary": context_config.use_community_summary,
            "shuffle_data": context_config.shuffle_data,
            "include_community_rank": context_config.include_community_rank,
            "min_community_rank": context_config.min_community_rank,
            "community_rank_name": context_config.community_rank_name,
            "include_community_weight": context_config.include_community_weight,
            "community_weight_name": context_config.community_weight_name,
            "normalize_community_weight": context_config.normalize_community_weight,
            "max_tokens": context_config.max_data_tokens,
            "context_name": "Reports",
        }

        # The map phase always requests a JSON object, independent of
        # ``mapreduce_config.json_mode`` (which is passed to GlobalSearch below).
        map_llm_params = {
            "max_tokens": mapreduce_config.map_max_tokens,
            "temperature": mapreduce_config.map_temperature,
            "response_format": {"type": "json_object"},
        }

        reduce_llm_params = {
            "max_tokens": mapreduce_config.reduce_max_tokens,
            "temperature": mapreduce_config.reduce_temperature,
        }

        self._search_engine = GlobalSearch(
            model=self._model,
            context_builder=context_builder,
            token_encoder=token_encoder,
            max_data_tokens=context_config.max_data_tokens,
            map_llm_params=map_llm_params,
            reduce_llm_params=reduce_llm_params,
            allow_general_knowledge=mapreduce_config.allow_general_knowledge,
            json_mode=mapreduce_config.json_mode,
            context_builder_params=context_builder_params,
            concurrent_coroutines=32,  # fixed value; not exposed through the configs
            response_type=mapreduce_config.response_type,
        )

    @staticmethod
    def _load_parquet_table(input_dir: str, table_name: str) -> pd.DataFrame:
        """Read ``<input_dir>/<table_name>.parquet`` into a DataFrame."""
        # Plain string formatting (rather than pathlib) leaves ``input_dir``
        # exactly in the form the caller supplied it.
        return pd.read_parquet(f"{input_dir}/{table_name}.parquet")  # type: ignore

    async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn:
        """Execute a global search for ``args.query`` and wrap the answer.

        Args:
            args: The validated tool arguments holding the query text.
            cancellation_token: Accepted to satisfy the tool interface; this
                method does not forward it to the search engine.

        Returns:
            The search engine's response wrapped in a GlobalSearchToolReturn.

        Raises:
            TypeError: If the search engine returns a non-string response.
        """
        search_result = await self._search_engine.search(args.query)
        response = search_result.response
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if not isinstance(response, str):
            raise TypeError("Expected response to be a string")
        return GlobalSearchToolReturn(answer=response)

    @classmethod
    def from_settings(cls, root_dir: str | Path, config_filepath: str | Path | None = None) -> "GlobalSearchTool":
        """Create a GlobalSearchTool instance from GraphRAG settings file.

        Args:
            root_dir: Path to the GraphRAG root directory
            config_filepath: Path to the GraphRAG settings file (optional)

        Returns:
            An initialized GlobalSearchTool instance

        Raises:
            ValueError: If the default chat model is missing from the loaded config.
        """
        # Normalize string inputs to Path objects before loading the config.
        if isinstance(root_dir, str):
            root_dir = Path(root_dir)
        if isinstance(config_filepath, str):
            config_filepath = Path(config_filepath)
        config = load_config(root_dir=root_dir, config_filepath=config_filepath)

        # Get the language model configuration from the models section
        chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID)
        if chat_model_config is None:
            raise ValueError("default_chat_model not found in config.models")

        # Initialize token encoder based on the model being used
        try:
            token_encoder = tiktoken.encoding_for_model(chat_model_config.model)
        except KeyError:
            # Fallback to cl100k_base if model is not recognized by tiktoken
            token_encoder = tiktoken.get_encoding("cl100k_base")

        # Create the LLM using ModelManager
        model = ModelManager().get_or_create_chat_model(
            name="global_search_model",
            model_type=chat_model_config.type,
            config=chat_model_config,
        )

        # The indexed tables are read from the configured output directory.
        data_config = DataConfig(
            input_dir=str(config.output.base_dir),
        )

        return cls(
            token_encoder=token_encoder,
            model=model,
            data_config=data_config,
            context_config=_default_context_config,
            mapreduce_config=_default_mapreduce_config,
        )
|
||||
245
agent_dhal/agentdhal_extensions/tools/graphrag/_local_search.py
Normal file
245
agent_dhal/agentdhal_extensions/tools/graphrag/_local_search.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# mypy: disable-error-code="no-any-unimported,misc"
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
from agentdhal_core import CancellationToken
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
from graphrag.language_model.protocol import ChatModel, EmbeddingModel
|
||||
from graphrag.query.indexer_adapters import (
|
||||
read_indexer_entities,
|
||||
read_indexer_relationships,
|
||||
read_indexer_text_units,
|
||||
)
|
||||
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
|
||||
from graphrag.query.structured_search.local_search.search import LocalSearch
|
||||
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
|
||||
from ._config import LocalContextConfig, SearchConfig
|
||||
from ._config import LocalDataConfig as DataConfig
|
||||
|
||||
# Module-level default configurations. They serve as the default arguments of
# ``LocalSearchTool.__init__`` and are what ``from_settings`` passes along.
# The instances are shared, so they should be treated as read-only.
_default_search_config = SearchConfig()
_default_context_config = LocalContextConfig()
|
||||
|
||||
|
||||
class LocalSearchToolArgs(BaseModel):
    """Argument schema of the local search tool: one free-text query."""

    query: str = Field(
        ...,
        description="The user query to perform local search on.",
    )
|
||||
|
||||
|
||||
class LocalSearchToolReturn(BaseModel):
    """Return schema of the local search tool: the generated answer text."""

    answer: str = Field(
        ...,
        description="The answer to the user query.",
    )
|
||||
|
||||
|
||||
class LocalSearchTool(BaseTool[LocalSearchToolArgs, LocalSearchToolReturn]):
    """Enables running GraphRAG local search queries as an agent tool.

    This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework.
    The search combines local document context with semantic embeddings to find relevant information.

    .. note::
        This tool requires the :code:`graphrag` extra for the :code:`agentdhal-ext` package.

        To install:

        .. code-block:: bash

            pip install -U "agentdhal-agentchat" "agentdhal-ext[graphrag]"

        Before using this tool, you must complete the GraphRAG setup and indexing process:

        1. Follow the GraphRAG documentation to initialize your project and settings
        2. Configure and tune your prompts for the specific use case
        3. Run the indexing process to generate the required data files
        4. Ensure you have the settings.yaml file from the setup process

        Please refer to the `GraphRAG documentation <https://microsoft.github.io/graphrag/>`_
        for detailed instructions on completing these prerequisite steps.

    Example usage with AssistantAgent:

    .. code-block:: python

        import asyncio
        from pathlib import Path
        from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
        from agentdhal_agentchat.ui import Console
        from agentdhal_extensions.tools.graphrag import LocalSearchTool
        from agentdhal_agentchat.agents import AssistantAgent


        async def main():
            # Initialize the OpenAI client
            openai_client = OpenAIChatCompletionClient(
                model="gpt-4o-mini",
                api_key="<api-key>",
            )

            # Set up local search tool
            local_tool = LocalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))

            # Create assistant agent with the local search tool
            assistant_agent = AssistantAgent(
                name="search_assistant",
                tools=[local_tool],
                model_client=openai_client,
                system_message=(
                    "You are a tool selector AI assistant using the GraphRAG framework. "
                    "Your primary task is to determine the appropriate search tool to call based on the user's query. "
                    "For specific, detailed information about particular entities or relationships, call the 'local_search_tool' function."
                ),
            )

            # Run a sample query
            query = "What does the station-master say about Dr. Becher?"
            await Console(assistant_agent.run_stream(task=query))


        if __name__ == "__main__":
            asyncio.run(main())

    Args:
        token_encoder (tiktoken.Encoding): The tokenizer used for text encoding
        model: The chat model to use for search (GraphRAG ChatModel)
        embedder: The text embedding model to use (GraphRAG EmbeddingModel)
        data_config (DataConfig): Configuration for data source locations and settings
        context_config (LocalContextConfig, optional): Configuration for context building. Defaults to default config.
        search_config (SearchConfig, optional): Configuration for search operations. Defaults to default config.
    """

    def __init__(
        self,
        token_encoder: tiktoken.Encoding,
        model: ChatModel,  # ChatModel from GraphRAG
        embedder: EmbeddingModel,  # EmbeddingModel from GraphRAG
        data_config: DataConfig,
        context_config: LocalContextConfig = _default_context_config,
        search_config: SearchConfig = _default_search_config,
    ):
        super().__init__(
            args_type=LocalSearchToolArgs,
            return_type=LocalSearchToolReturn,
            name="local_search_tool",
            description="Perform a local search with given parameters using graphrag.",
        )
        self._model = model
        self._embedder = embedder

        # Load the indexed tables (same order as before: entities,
        # relationships, text units, communities).
        input_dir = data_config.input_dir
        entity_df = self._load_parquet_table(input_dir, data_config.entity_table)
        relationship_df = self._load_parquet_table(input_dir, data_config.relationship_table)
        text_unit_df = self._load_parquet_table(input_dir, data_config.text_unit_table)
        community_df = self._load_parquet_table(input_dir, data_config.community_table)

        # Convert the raw frames into GraphRAG query objects.
        entities = read_indexer_entities(entity_df, community_df, data_config.community_level)
        relationships = read_indexer_relationships(relationship_df)
        text_units = read_indexer_text_units(text_unit_df)

        # Entity description embeddings are looked up in a LanceDB store that
        # lives in the ``lancedb`` sub-directory of the input directory.
        description_embedding_store = LanceDBVectorStore(
            collection_name="default-entity-description",
        )
        description_embedding_store.connect(db_uri=f"{input_dir}/lancedb")

        context_builder = LocalSearchMixedContext(
            entities=entities,
            entity_text_embeddings=description_embedding_store,
            text_embedder=self._embedder,
            text_units=text_units,
            relationships=relationships,
            token_encoder=token_encoder,
        )

        context_builder_params = {
            "text_unit_prop": context_config.text_unit_prop,
            "community_prop": context_config.community_prop,
            "include_entity_rank": context_config.include_entity_rank,
            "rank_description": context_config.rank_description,
            "include_relationship_weight": context_config.include_relationship_weight,
            "relationship_ranking_attribute": context_config.relationship_ranking_attribute,
            "max_tokens": context_config.max_data_tokens,
        }

        llm_params = {
            "max_tokens": search_config.max_tokens,
            "temperature": search_config.temperature,
        }

        self._search_engine = LocalSearch(
            model=self._model,
            context_builder=context_builder,
            token_encoder=token_encoder,
            response_type=search_config.response_type,
            context_builder_params=context_builder_params,
            model_params=llm_params,
        )

    @staticmethod
    def _load_parquet_table(input_dir: str, table_name: str) -> pd.DataFrame:
        """Read ``<input_dir>/<table_name>.parquet`` into a DataFrame."""
        # Plain string formatting (rather than pathlib) leaves ``input_dir``
        # exactly in the form the caller supplied it.
        return pd.read_parquet(f"{input_dir}/{table_name}.parquet")  # type: ignore

    async def run(self, args: LocalSearchToolArgs, cancellation_token: CancellationToken) -> LocalSearchToolReturn:
        """Execute a local search for ``args.query`` and wrap the answer.

        Args:
            args: The validated tool arguments holding the query text.
            cancellation_token: Accepted to satisfy the tool interface; this
                method does not forward it to the search engine.

        Returns:
            The search engine's response wrapped in a LocalSearchToolReturn.

        Raises:
            TypeError: If the search engine returns a non-string response.
        """
        search_result = await self._search_engine.search(args.query)  # type: ignore[reportUnknownMemberType]
        response = search_result.response
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if not isinstance(response, str):
            raise TypeError("Expected response to be a string")
        return LocalSearchToolReturn(answer=response)

    @classmethod
    def from_settings(cls, root_dir: str | Path, config_filepath: str | Path | None = None) -> "LocalSearchTool":
        """Create a LocalSearchTool instance from GraphRAG settings file.

        Args:
            root_dir: Path to the GraphRAG root directory
            config_filepath: Path to the GraphRAG settings file (optional)

        Returns:
            An initialized LocalSearchTool instance

        Raises:
            ValueError: If the default chat model or the default embedding
                model is missing from the loaded config.
        """
        # Accept plain strings as well as Path objects, mirroring
        # GlobalSearchTool.from_settings.
        if isinstance(root_dir, str):
            root_dir = Path(root_dir)
        if isinstance(config_filepath, str):
            config_filepath = Path(config_filepath)
        config = load_config(root_dir=root_dir, config_filepath=config_filepath)

        # Get the language model configurations from the models section
        chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID)
        embedding_model_config = config.models.get(defs.DEFAULT_EMBEDDING_MODEL_ID)

        if chat_model_config is None:
            raise ValueError("default_chat_model not found in config.models")
        if embedding_model_config is None:
            raise ValueError("default_embedding_model not found in config.models")

        # Initialize token encoder based on the model being used
        try:
            token_encoder = tiktoken.encoding_for_model(chat_model_config.model)
        except KeyError:
            # Fallback to cl100k_base if model is not recognized by tiktoken
            token_encoder = tiktoken.get_encoding("cl100k_base")

        # Create the models using ModelManager
        model = ModelManager().get_or_create_chat_model(
            name="local_search_model",
            model_type=chat_model_config.type,
            config=chat_model_config,
        )

        embedder = ModelManager().get_or_create_embedding_model(
            name="local_search_embedder",
            model_type=embedding_model_config.type,
            config=embedding_model_config,
        )

        # The indexed tables are read from the configured output directory.
        data_config = DataConfig(
            input_dir=str(config.output.base_dir),
        )

        return cls(
            token_encoder=token_encoder,
            model=model,
            embedder=embedder,
            data_config=data_config,
            context_config=_default_context_config,
            search_config=_default_search_config,
        )
|
||||
Reference in New Issue
Block a user