first commit
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from ._video_surfer import VideoSurfer
|
||||
|
||||
__all__ = ["VideoSurfer"]
|
||||
@@ -0,0 +1,172 @@
|
||||
from typing import Any, Awaitable, Callable, List, Optional
|
||||
|
||||
from agentdhal_agentchat.agents import AssistantAgent
|
||||
from agentdhal_core.models import ChatCompletionClient
|
||||
from agentdhal_core.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .tools import (
|
||||
extract_audio,
|
||||
get_screenshot_at,
|
||||
get_video_length,
|
||||
save_screenshot,
|
||||
transcribe_audio_with_timestamps,
|
||||
transcribe_video_screenshot,
|
||||
)
|
||||
|
||||
|
||||
class VideoSurfer(AssistantAgent):
    """
    VideoSurfer is a specialized agent designed to answer questions about a local video file.

    Installation:

    .. code-block:: bash

        pip install "agentdhal-ext[video-surfer]"

    The agent relies on a set of tools to pull information out of the video: its length,
    screenshots at given timestamps, and audio transcriptions. It combines these pieces to
    give detailed answers to user queries.

    Available tools:

    - :func:`~agentdhal_extensions.agents.video_surfer.tools.extract_audio`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.get_video_length`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.transcribe_audio_with_timestamps`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.get_screenshot_at`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.save_screenshot`
    - :func:`~agentdhal_extensions.agents.video_surfer.tools.transcribe_video_screenshot`

    Args:
        name (str): The name of the agent.
        model_client (ChatCompletionClient): The model client used for generating responses.
        tools (List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional):
            A list of tools or functions the agent can use. If not provided (or empty), defaults to all video tools from the action space.
        description (str, optional): A brief description of the agent. Defaults to "An agent that can answer questions about a local video.".
        system_message (str | None, optional): The system message guiding the agent's behavior. Defaults to a predefined message.

    Example usage:

    The following example demonstrates how to create a video surfing agent with
    a model client and generate a response to a simple query about a local video
    called video.mp4.

        .. code-block:: python


            import asyncio
            from agentdhal_agentchat.ui import Console
            from agentdhal_agentchat.conditions import TextMentionTermination
            from agentdhal_agentchat.teams import RoundRobinGroupChat
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.agents.video_surfer import VideoSurfer

            async def main() -> None:
                \"\"\"
                Main function to run the video agent.
                \"\"\"
                # Define an agent
                video_agent = VideoSurfer(
                    name="VideoSurfer",
                    model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")
                    )

                # Define termination condition
                termination = TextMentionTermination("TERMINATE")

                # Define a team
                agent_team = RoundRobinGroupChat([video_agent], termination_condition=termination)

                # Run the team and stream messages to the console
                stream = agent_team.run_stream(task="How does Adam define complex tasks in video.mp4? What concrete example of a complex task does he use? Can you save this example to disk as well?")
                await Console(stream)

            asyncio.run(main())

    The following example demonstrates how to create and use a VideoSurfer and UserProxyAgent with MagenticOneGroupChat.

        .. code-block:: python

            import asyncio

            from agentdhal_agentchat.ui import Console
            from agentdhal_agentchat.teams import MagenticOneGroupChat
            from agentdhal_agentchat.agents import UserProxyAgent
            from agentdhal_extensions.models.openai import OpenAIChatCompletionClient
            from agentdhal_extensions.agents.video_surfer import VideoSurfer

            async def main() -> None:
                \"\"\"
                Main function to run the video agent.
                \"\"\"

                model_client = OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")

                # Define an agent
                video_agent = VideoSurfer(
                    name="VideoSurfer",
                    model_client=model_client
                    )

                user_proxy_agent = UserProxyAgent(
                    name="User"
                )

                # Define a team
                agent_team = MagenticOneGroupChat([user_proxy_agent, video_agent], model_client=model_client,)

                # Run the team and stream messages to the console
                stream = agent_team.run_stream(task="Find a latest video about magentic one on youtube and extract quotes from it that make sense.")
                await Console(stream)

            asyncio.run(main())
    """

    # Fallback used when the caller passes no (or an empty) description.
    DEFAULT_DESCRIPTION = "An agent that can answer questions about a local video."

    # Fallback system prompt used when the caller passes no (or an empty) system_message.
    # NOTE(review): the text below is kept flush-left exactly as it appears in this view of
    # the file; the original indentation inside the literal could not be confirmed.
    DEFAULT_SYSTEM_MESSAGE = """
You are a helpful agent that is an expert at answering questions from a video.
When asked to answer a question about a video, you should:
1. Check if that video is available locally.
2. Use the transcription to find which part of the video the question is referring to.
3. Optionally use screenshots from those timestamps
4. Provide a detailed answer to the question.
Reply with TERMINATE when the task has been completed.
"""

    def __init__(
        self,
        name: str,
        model_client: ChatCompletionClient,
        *,
        tools: List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
        description: Optional[str] = None,
        system_message: Optional[str] = None,
    ):
        # A missing or empty tool list falls back to the full set of video tools.
        # The screenshot transcriber is exposed through the bound method below so that
        # it can reach this agent's model client.
        selected_tools = tools
        if not selected_tools:
            selected_tools = [
                get_video_length,
                get_screenshot_at,
                save_screenshot,
                self.vs_transribe_video_screenshot,
                extract_audio,
                transcribe_audio_with_timestamps,
            ]

        # Falsy values (None or "") select the class-level defaults.
        chosen_description = description if description else self.DEFAULT_DESCRIPTION
        chosen_system_message = system_message if system_message else self.DEFAULT_SYSTEM_MESSAGE

        super().__init__(
            name=name,
            model_client=model_client,
            tools=selected_tools,
            description=chosen_description,
            system_message=chosen_system_message,
        )

    async def vs_transribe_video_screenshot(self, video_path: str, timestamp: float) -> str:
        """
        Transcribes the video screenshot at a specific timestamp.

        Args:
            video_path (str): Path to the video file.
            timestamp (float): Timestamp to take the screenshot.

        Returns:
            str: Transcription of the video screenshot.
        """
        # Delegate to the module-level tool, supplying this agent's own model client.
        transcription = await transcribe_video_screenshot(video_path, timestamp, self._model_client)
        return transcription
|
||||
156
agent_dhal/agentdhal_extensions/agents/video_surfer/tools.py
Normal file
156
agent_dhal/agentdhal_extensions/agents/video_surfer/tools.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import base64
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import cv2
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
import whisper
|
||||
from agentdhal_core import Image as AGImage
|
||||
from agentdhal_core.models import (
|
||||
ChatCompletionClient,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
|
||||
def extract_audio(video_path: str, audio_output_path: str) -> str:
    """
    Extracts audio from a video file and saves it as an MP3 file.

    :param video_path: Path to the video file.
    :param audio_output_path: Path to save the extracted audio file.
    :return: Confirmation message with the path to the saved audio file.
    """
    # Assemble the ffmpeg pipeline step by step: video input -> MP3-formatted output.
    video_input = ffmpeg.input(video_path)  # type: ignore
    mp3_output = video_input.output(audio_output_path, format="mp3")  # type: ignore

    # Execute the pipeline with the quiet and overwrite_output options enabled.
    mp3_output.run(quiet=True, overwrite_output=True)  # type: ignore

    return f"Audio extracted and saved to {audio_output_path}."
|
||||
|
||||
|
||||
def transcribe_audio_with_timestamps(audio_path: str) -> str:
    """
    Transcribes the audio file with timestamps using the Whisper model.

    :param audio_path: Path to the audio file.
    :return: Transcription with timestamps.
    """
    whisper_model = whisper.load_model("base")  # type: ignore
    result: Dict[str, Any] = whisper_model.transcribe(  # type: ignore
        audio_path, task="transcribe", language="en", verbose=False
    )

    segments: List[Dict[str, Any]] = result["segments"]

    # One "[start - end] text" line per segment, each terminated by a newline.
    formatted_lines = [
        f"[{segment['start']:.2f} - {segment['end']:.2f}] {segment['text']}\n" for segment in segments
    ]
    return "".join(formatted_lines)
|
||||
|
||||
|
||||
def get_video_length(video_path: str) -> str:
    """
    Returns the length of the video in seconds.

    :param video_path: Path to the video file.
    :return: Duration of the video in seconds.
    :raises IOError: If the video cannot be opened or its frame rate cannot be determined.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video file {video_path}")

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # A frame rate that is zero, negative or NaN would make the division below
        # fail or produce a meaningless duration, so report it explicitly.
        if not fps > 0:
            raise IOError(f"Cannot determine frame rate of video file {video_path}")
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        duration = frame_count / fps
    finally:
        # Always release the capture handle, even when an error is raised above.
        cap.release()

    return f"The video is {duration:.2f} seconds long."
|
||||
|
||||
|
||||
def save_screenshot(video_path: str, timestamp: float, output_path: str) -> None:
    """
    Captures a screenshot at the specified timestamp and saves it to the output path.

    :param video_path: Path to the video file.
    :param timestamp: Timestamp in seconds.
    :param output_path: Path to save the screenshot. The file format is determined by the extension in the path.
    :raises IOError: If the video cannot be opened, the frame cannot be read, or the image cannot be written.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video file {video_path}")

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Convert the timestamp to a frame index and seek to it before reading.
        frame_number = int(timestamp * fps)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, frame = cap.read()
    finally:
        # Release the capture handle on every path, including the failure paths below.
        cap.release()

    if not ret:
        raise IOError(f"Failed to capture frame at {timestamp:.2f}s")

    # imwrite reports failure through its return value; surface it instead of
    # silently leaving no file behind.
    if not cv2.imwrite(output_path, frame):
        raise IOError(f"Failed to write screenshot to {output_path}")
|
||||
|
||||
|
||||
async def transcribe_video_screenshot(video_path: str, timestamp: float, model_client: ChatCompletionClient) -> str:
    """
    Transcribes the content of a video screenshot captured at the specified timestamp using the given model client.

    :param video_path: Path to the video file.
    :param timestamp: Timestamp in seconds.
    :param model_client: ChatCompletionClient instance.
    :return: Description of the screenshot content.
    """
    screenshots = get_screenshot_at(video_path, [timestamp])
    if not screenshots:
        return "Failed to capture screenshot."

    _, frame = screenshots[0]

    # Encode the frame as JPEG, then wrap the bytes in a base64 data URI.
    encoded_ok, buffer = cv2.imencode(".jpg", frame)
    if not encoded_ok:
        return "Failed to encode screenshot."
    frame_base64 = base64.b64encode(buffer.tobytes()).decode("utf-8")
    screenshot_uri = f"data:image/jpeg;base64,{frame_base64}"

    messages = [
        UserMessage(
            content=[
                # The timestamp is interpolated so the model knows where in the video
                # the frame comes from (the placeholder was previously left unformatted).
                f"Following is a screenshot from the video at {timestamp} seconds. Describe what you see here.",
                AGImage.from_uri(screenshot_uri),
            ],
            source="tool",
        )
    ]

    result = await model_client.create(messages=messages)
    return str(result.content)
|
||||
|
||||
|
||||
def get_screenshot_at(video_path: str, timestamps: List[float]) -> List[Tuple[float, np.ndarray[Any, Any]]]:
    """
    Captures screenshots at the specified timestamps and returns them as Python objects.

    :param video_path: Path to the video file.
    :param timestamps: List of timestamps in seconds.
    :return: List of tuples containing timestamp and the corresponding frame (image).
             Each frame is a NumPy array (height x width x channels).
    :raises IOError: If the video cannot be opened, its frame rate cannot be determined,
             or a frame cannot be read.
    :raises ValueError: If a timestamp lies outside the video's duration.
    """
    screenshots: List[Tuple[float, np.ndarray[Any, Any]]] = []

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video file {video_path}")

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # A frame rate that is zero, negative or NaN would break the duration
        # computation below, so report it explicitly.
        if not fps > 0:
            raise IOError(f"Cannot determine frame rate of video file {video_path}")
        total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        duration = total_frames / fps

        for timestamp in timestamps:
            if not 0 <= timestamp <= duration:
                raise ValueError(f"Timestamp {timestamp:.2f}s is out of range [0s, {duration:.2f}s]")

            # Seek to the frame that corresponds to the timestamp, then read it.
            frame_number = int(timestamp * fps)
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
            ret, frame = cap.read()
            if not ret:
                raise IOError(f"Failed to capture frame at {timestamp:.2f}s")

            screenshots.append((timestamp, frame))
    finally:
        # Release the capture handle on success and on every error path.
        cap.release()

    return screenshots
|
||||
Reference in New Issue
Block a user