diff --git a/agent/data_agent.py b/agent/data_agent.py new file mode 100644 index 0000000..e0d68e9 --- /dev/null +++ b/agent/data_agent.py @@ -0,0 +1,113 @@ +from typing import TypedDict, Annotated, Sequence, List, Optional +import operator +from langchain_openai import AzureChatOpenAI +from langgraph.graph import StateGraph, END +from langgraph.checkpoint.memory import MemorySaver + +from tools.scraping import gen_queries, get_video_ids, download, VideoInfo +from tools.video_chunking import detect_segments, SegmentInfo +from tools.annotating import extract_clues, gen_annotations + +from tools.prompts import ( + GEN_QUERIES_PROMPT, + EXTRACT_CLUES_PROMPT, + GEN_ANNOTATIONS_PROMPT, +) + + +llm = AzureChatOpenAI( + temperature=0.0, + azure_deployment="gpt4o", + openai_api_version="2023-07-01-preview", +) + +memory = MemorySaver() +# memory = SqliteSaver.from_conn_string(":memory:") + + +class AgentState(TypedDict): + task: str + search_queries: List[str] + video_ids: List[str] + video_infos: List[VideoInfo] + clip_text_prompts: List[str] + segment_infos: List[SegmentInfo] + clues: List[str] + annotations: List[str] + + +class DataAgent: + def __init__(self, llm, memory): + self.llm = llm + self.memory = memory + self.graph = self.build_graph() + + def build_graph(self): + builder = StateGraph(AgentState) + + builder.add_node("generate_queries", self.gen_queries_node) + builder.add_node("get_video_ids", self.get_video_ids_node) + builder.add_node("download", self.download_node) + builder.add_node("detect_segments", self.detect_segments_node) + builder.add_node("extract_clues", self.extract_clues_node) + builder.add_node("gen_annotations", self.gen_annotations_node) + + builder.set_entry_point("generate_queries") + + builder.add_edge("generate_queries", "get_video_ids") + builder.add_edge("get_video_ids", "download") + builder.add_edge("download", "detect_segments") + builder.add_edge("detect_segments", "extract_clues") + builder.add_edge("extract_clues", "gen_annotations") + builder.add_edge("gen_annotations", END) + + graph = builder.compile(checkpointer=self.memory) + + return graph + + def gen_queries_node(self, state: AgentState): + search_queries = gen_queries(self.llm, state["task"], GEN_QUERIES_PROMPT) + return {"search_queries": search_queries[:2]} + + def get_video_ids_node(self, state: AgentState): + video_ids = get_video_ids(state["search_queries"]) + return {"video_ids": video_ids} + + def download_node(self, state: AgentState): + video_infos = download(state["video_ids"]) + return {"video_infos": video_infos} + + def detect_segments_node(self, state: AgentState): + segment_infos = detect_segments( + state["video_infos"], state["clip_text_prompts"] + ) + return {"segment_infos": segment_infos} + + def extract_clues_node(self, state: AgentState): + clues = extract_clues( + self.llm, + EXTRACT_CLUES_PROMPT, + state["segment_infos"], + state["video_infos"], + ) + return {"clues": clues} + + def gen_annotations_node(self, state: AgentState): + annotations = gen_annotations(self.llm, GEN_ANNOTATIONS_PROMPT, state["clues"]) + return {"annotations": annotations} + + def run(self, task: str, thread_id: str): + thread = {"configurable": {"thread_id": thread_id}} + for step in self.graph.stream( + { + "task": task, + "clip_text_prompts": ["person doing squats"], + }, + thread, + ): + if "download" in step: + print("download happened") + elif "extract_clues" in step: + print("extract_clues happened") + else: + print(step) diff --git a/agent/run_agent.ipynb b/agent/run_agent.ipynb new file mode
100644 index 0000000..f6252ad --- /dev/null +++ b/agent/run_agent.ipynb @@ -0,0 +1,98 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dotenv import load_dotenv\n", + "_ = load_dotenv()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "torch.device(\"cuda:0\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from data_agent import DataAgent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_openai import AzureChatOpenAI\n", + "from langgraph.checkpoint.memory import MemorySaver\n", + "\n", + "\n", + "llm = AzureChatOpenAI(\n", + " temperature=0.0,\n", + " azure_deployment=\"gpt4o\",\n", + " openai_api_version=\"2023-07-01-preview\",\n", + ")\n", + "\n", + "memory = MemorySaver()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent = DataAgent(llm, memory)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(\"i wanna teach people how to do squats\", thread_id=\"1\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "vlm_databuilder_agent", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/agent/tools/annotating.py b/agent/tools/annotating.py new file mode 100644 index 0000000..d38b547 --- /dev/null +++ b/agent/tools/annotating.py @@ -0,0 +1,181 @@ +from typing import List, Optional +from collections import defaultdict +from langchain.pydantic_v1 import BaseModel, Field +from langchain_core.prompts import ChatPromptTemplate + +# 4. Create nodes + +from .scraping import VideoInfo +from .video_chunking import SegmentInfo + + +class LocalClue(BaseModel): + """Local clues for a segment""" + + id: str = Field(description="LC1,LC2...") + quote: str = Field( + description="the quote from the transcript that was used to create this clue." + ) + quote_timestamp_start: str = Field( + description="the exact start timestamp of the quote." + ) + quote_timestamp_end: str = Field( + description="the exact end timestamp of the quote." + ) + clue: str = Field(description="the main clue data") + + +class GlobalClue(BaseModel): + """Global clues for a segment""" + + id: str = Field(description="GC1,GC2...") + quote: str = Field( + description="the quote from the transcript that was used to create this clue." + ) + quote_timestamp_start: str = Field( + description="the exact start timestamp of the quote." + ) + quote_timestamp_end: str = Field( + description="the exact end timestamp of the quote." + ) + clue: str = Field(description="the main clue data.") + relevance_to_segment: str = Field( + description="why do you think this global clue is relevant to the segment you are working with right now." 
+ ) + + +class LogicalInference(BaseModel): + """Logical inferences for a segment""" + + id: str = Field(description="LI1,LI2,...") + description: str = Field(description="A concise form of the logical inference.") + details: str = Field( + description="A verbose explanation of what insight about what happens in this segment should be made based on the clues that you found." + ) + + +class SegmentAnnotation(BaseModel): + local_clues: list[LocalClue] = Field( + description="Local clues are inside the segment in terms of timestamps." + ) + global_clues: list[GlobalClue] = Field( + description="Global clues are scattered across the entire transcript." + ) + logical_inferences: list[LogicalInference] = Field( + description="What inferences about the topic the user is looking for in the video can we make based on the clues inside this segment" + ) + + +class SegmentWithClueInfo(BaseModel): + """ + Annotation for a video segment. + """ + + start_timestamp: str = Field( + description="start timestamp of the segment in format HH:MM:SS.MS" + ) + end_timestamp: str = Field( + description="end timestamp of the segment in format HH:MM:SS.MS" + ) + segment_annotation: SegmentAnnotation = Field( + description="list of annotations for the segment" + ) + + +class VideoAnnotation(BaseModel): + """ + Segments of a video. + """ + + segments: list[SegmentWithClueInfo] = Field( + description="information about each segment" + ) + + +def extract_clues( + llm, + system_prompt: str, + segment_infos: List[SegmentInfo], + video_infos: List[VideoInfo], +): + + prompt_template = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + ( + "user", + "Segment timecodes: {{ segment_timecodes }}\nTranscript: {{ transcript }}", + ), + ], + template_format="jinja2", + ) + + model = prompt_template | llm.with_structured_output(VideoAnnotation) + + segment_infos_dict = defaultdict(list) + for segment_info in segment_infos: + segment_infos_dict[segment_info.video_id].append(segment_info) + + video_infos_dict = {video_info.video_id: video_info for video_info in video_infos} + + clues = [] + + for video_id, video_segment_infos in segment_infos_dict.items(): + transcript = video_infos_dict[video_id].transcript + segment_infos_chunks = [ + video_segment_infos[i : i + 5] for i in range(0, len(video_segment_infos), 5) + ] + + for chunk in segment_infos_chunks: + video_annotation: VideoAnnotation = model.invoke( + { + "segment_timecodes": "\n".join( + [f"{s.start_timestamp}-{s.end_timestamp}" for s in chunk] + ), + "transcript": transcript, + } + ) + clues.extend(video_annotation.segments) + + return clues + + +def gen_annotations(llm, system_prompt: str, clues: List[SegmentWithClueInfo]): + class SegmentFeedback(BaseModel): + right: Optional[str] = Field(description="what was right in the performance") + wrong: Optional[str] = Field(description="what was wrong in the performance") + correction: Optional[str] = Field( + description="how and in what ways the performance could be improved" + ) + + # The segment timestamps are taken from the provided information. + class SegmentCompleteAnnotation(BaseModel): + squats_probability: Optional[str] = Field( + description="how high is the probability that the person is doing squats in the segment: low, medium, high, unknown(null)" + ) + squats_technique_correctness: Optional[str] = Field( + description="correctness of the squat technique." + ) + squats_feedback: Optional[SegmentFeedback] = Field( + description="what was right and wrong in the squat performance in the segment.
When the technique is incorrect, provide instructions on how to correct it." + ) + + prompt_template = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + ("user", "Clues: {{ clues }}"), + ], + template_format="jinja2", + ) + + model = prompt_template | llm.with_structured_output(SegmentCompleteAnnotation) + + annotations = [] + for clue in clues: + segment_annotation: SegmentCompleteAnnotation = model.invoke( + {"clues": clue.json()} + ) + + annotations.append(segment_annotation.json()) + + return annotations diff --git a/agent/tools/prompts.py b/agent/tools/prompts.py new file mode 100644 index 0000000..bc99503 --- /dev/null +++ b/agent/tools/prompts.py @@ -0,0 +1,217 @@ +# 3. Set prompts + +GEN_QUERIES_PROMPT = ( + "You are helping the user to find a very large and diverse set of videos on a video hosting service. " + "A user will only describe which videos they are looking for and how many queries they need." +) + +# prompt='I want to find instructional videos about how to do squats.', +# num_queries_prompt = f'I need {num_queries} queries' + +EXTRACT_CLUES_PROMPT = """You are a highly intelligent data investigator. +You take unstructured damaged data and look for clues that could help restore the initial information +and extract important insights from it. +You are the best one for this job in the world because you are a former detective. +You care about even the smallest details, and your guesses about what happened in the initial file +even at very limited inputs are usually absolutely right. +You use deductive and inductive reasoning at the highest possible quality. + +#YOUR TODAY'S JOB +The user needs to learn about what happens in a specific segment of a video file. Your job is to help the user by providing clues that would help the user make the right assumption. +The user will provide you with: +1. Instructions about what kind of information the user is trying to obtain. +2. A list of time codes of the segments in format "<start_timestamp>-<end_timestamp>". All the provided segments of the video contain what the user is looking for, but other parts of the video might have different content. +3. A transcript of the *full video* in format of "<timestamp>\\n<caption text>" + +Your task: +1. Read the transcript. +2. Provide the clues in a given format. +3. Provide any other info requested by the user. + +#RULES +!!! VERY IMPORTANT !!! +1. Rely only on the data provided in the transcript. Do not improvise. All the quotes and corresponding timestamps must be taken from the transcript. Quote timestamps must be taken directly from the transcript. +2. Your job is to find the data already provided in the transcript. +3. Analyze every segment. Only skip a segment if there is no information about it in the transcript. +4. For local clues, make sure that the quotes that you provide are located inside the segment. To do this, double check the timestamps from the transcript and the segment. +5. For all clues, make sure that the quotes exactly correspond to the timestamps that you provide. +6. When making clues, try as much as possible to make them describe specifically what is shown in the segment. +7. Follow the output format. +8. Be very careful with details. Don't generalize. Always double check your results. + +Please, help the user find relevant clues to reconstruct the information they are looking for, for each provided segment.
+ +WHAT IS A CLUE: A *clue*, in the context of reconstructing narratives from damaged data, +is a fragment of information extracted from a corrupted or incomplete source that provides +insight into the original content. These fragments serve as starting points for inference +and deduction, allowing researchers to hypothesize about the fuller context or meaning of +the degraded material. The process of identifying and interpreting clues involves both objective analysis of the +available data and subjective extrapolation based on domain knowledge, contextual understanding, +and logical reasoning. + +Here is what the user expects to have from you: +1. *Local clues* that would help the user understand how the thing they are looking for happens inside the segment. Local clues for a segment are generated from quotes inside a specific segment. +2. *Global clues* that would help the user understand how the thing they are looking for happens inside the segment. Global clues for a segment are generated from quotes all around the video, but are very relevant to the specific segment that they are provided for. +3. *Logical inferences* that could help the user understand how the thing they are looking for happens inside the segment. Logical inferences for a segment are deduced from local and global clues for this segment. + +!!!IT IS EXTREMELY IMPORTANT TO DELIVER ALL THREE THINGS!!! + + Good local clues examples: [ + { + "id": "LC1", + "timestamp": "00:00:19", + "quote": "exercises do them wrong and instead of", + "clue": "This phrase introduces the concept of incorrect exercise form, setting the stage for a demonstration of improper technique." + }, + { + "id": "LC2", + "timestamp": "00:00:21", + "quote": "growing nice quads and glutes you'll", + "clue": "Mentions the expected benefits of proper squats (muscle growth), implying that these benefits won't be achieved with incorrect form." + }, + { + "id": "LC3", + "timestamp": "00:00:22", + "quote": "feel aches and pains in your knees your", + "clue": "Directly states negative consequences of improper form, strongly suggesting that this segment demonstrates incorrect technique." + }, + { + "id": "LC4", + "timestamp": "00:00:24", + "quote": "lower back and even your shoulders", + "clue": "Continuation of LC3, emphasizing multiple areas of potential pain from improper form." + }, + { + "id": "LC5", + "timestamp": "00:00:26", + "quote": "let's see how to do it correctly", + "clue": "This phrase suggests a transition is about to occur. The incorrect form has been shown, and correct form will follow." + } + ] + + Double check that the timestamp and the quote that you provide exactly correspond to what you found in the transcript. + For example, if the transcript says: + "00:05:02 + he took the glasses + 00:05:04 + and gave them to me" + Then a GOOD output will be: + - timestamp: 00:05:03 + - quote: "he took the glasses and gave them to me" + And a BAD output would be: + - timestamp: 00:04:02 + - quote: "he gave me the glasses" + + Good global clues examples: [ + { + "id": "GC1", + "timestamp": "00:01:15", + "quote": "Before we dive into specific techniques, let's talk about safety.", + "clue": "Introduces the theme of safety in squatting.", + "relevance_to_segment": "This earlier emphasis on safety provides context for why proper depth is important and why it's being addressed in our segment. It connects to the fear of knee pain mentioned in LC3."
+ }, + { + "id": "GC2", + "timestamp": "00:02:30", + "quote": "Squatting is a fundamental movement pattern in everyday life.", + "clue": "Emphasizes the importance of squats beyond just exercise.", + "relevance_to_segment": "This broader context heightens the importance of learning proper squat depth as demonstrated in our segment. It suggests that the techniques shown have applications beyond just gym workouts." + }, + { + "clue_id": "GC3", + "timestamp": "00:05:20", + "quote": "If you have existing knee issues, consult a physician before attempting deep squats.", + "clue": "Provides a health disclaimer related to squat depth.", + "relevance_to_segment": "While this comes after our segment, it's relevant because it addresses the concern about knee pain mentioned in LC3. It suggests that the demonstration in our segment is generally safe but acknowledges individual variations." + }, + { + "clue_id": "GC4", + "timestamp": "00:06:45", + "quote": "Proper depth ensures full engagement of your quadriceps and glutes.", + "clue": "Explains the benefit of correct squat depth.", + "relevance_to_segment": "This later explanation provides justification for the depth guideline given in LC4. It helps viewers understand why the demonstrated technique is important." + }, + { + "clue_id": "GC5", + "timestamp": "00:00:30", + "quote": "Today, we'll cover squat variations for beginners to advanced lifters.", + "clue": "Outlines the scope of the entire video.", + "relevance_to_segment": "This early statement suggests that our segment, focusing on proper depth, is part of a comprehensive guide. It implies that the demonstration might be adaptable for different skill levels." + } + ] + Double check that the timestamp and the quote that you provide exactly correspond to what you found in the transcript. + For example, if the transcript says: + "00:05:02 + he took the glasses + 00:05:04 + and gave them to me" + Then a GOOD output will be: + - timestamp: 00:05:03 + - quote: "he took the glasses and gave them to me" + And a BAD output would be: + - timestamp: 00:04:02 + - quote: "he gave me the glasses" + + + Good logical inference examples: + [ + { + "id": "LI1", + "description": "Primary Demonstration of Heel Lift", + "details": "Given that GC1-GC3 describe the 'most common mistake' as heels lifting off the ground, and this description immediately precedes our segment, it's highly probable that this is the primary error being demonstrated. This is further supported by the segment's focus on incorrect form (LC1-LC4)." + }, + { + "id": "LI2", + "description": "Multiple Error Demonstration", + "details": "While heel lift is likely the primary focus, the mention of multiple pain points (knees, lower back, shoulders in LC3-LC4) suggests that the demonstrator may be exhibiting several forms of incorrect technique simultaneously. This comprehensive 'what not to do' approach would be pedagogically effective." + }, + { + "id": "LI3", + "description": "Possible Inclusion of 'Butt Wink'", + "details": "Although 'butt wink' is mentioned after our segment (GC4-GC6), its connection to back pain (which is mentioned in LC4) raises the possibility that this error is also present in the demonstration. The instructor may be showing multiple errors early on, then breaking them down individually later." 
+ }, + { + "id": "LI4", + "description": "Segment Placement in Overall Video Structure", + "details": "The segment's position (starting at 00:00:19) and the phrase 'let's see how to do it correctly' (LC5) at the end suggest this is an early, foundational part of the video. It likely serves to grab attention by showing common mistakes before transitioning to proper form instruction." + }, + { + "id": "LI5", + "description": "Intentional Exaggeration of Errors", + "details": "Given the educational nature of the video, it's plausible that the demonstrator is intentionally exaggerating the incorrect form. This would make the errors more obvious to viewers and enhance the contrast with correct form shown later." + } + ] +""" + + +GEN_ANNOTATIONS_PROMPT = """You are a helpful assistant that performs high quality data investigation and transformation. + You will be given a JSON object with clues and other helpful information about what's going on + in a specific part of a video file. This part is called a segment. Your job is to: + 1. Read this JSON object carefully + 2. Answer user's questions about this segment + 3. Provide the answer as a JSON object in a schema provided by the user + Important rules: + 1. You can only rely on data presented in a provided JSON object. Don't improvise. + 2. Follow user's request carefully. + 3. Don't rush to deliver the answer. Take some time to think. Make a deep breath. Then start writing. + 4. If you want to output field as empty (null), output it as JSON null (without quotes), not as a string "null". +—> GOOD EXAMPLES: + "wrong":"Knees caving in: This can stress the knees and reduce effectiveness" + "correction":"Focus on keeping knees aligned with your toes." + "wrong":"Rounding the back: This increases the risk of back injuries" + "correction":"Keep your chest up and maintain a neutral spine throughout the movement." + "wrong":"Heels are lifting off the ground: this shifts the weight forward, reducing stability" + "correction":" Keep your weight on your heels and press through them as you rise." + "right":"Chest and shoulders: The chest is up, and the shoulders are back, maintaining an upright torso." 
+ "correction":null +—> BAD EXAMPLES: + "wrong":"knees" + "correction":"fix knees" + "wrong":"back looks funny" + "correction":"make back better" + "wrong":"feet are doing something" + "correction":"feet should be different" + "right":"arms" + "correction":"arms are fine i think" +—> BAD EXAMPLES END HERE +""" diff --git a/agent/tools/scraping.py b/agent/tools/scraping.py new file mode 100644 index 0000000..79d3c61 --- /dev/null +++ b/agent/tools/scraping.py @@ -0,0 +1,150 @@ +from typing import List + +import scrapetube +import yt_dlp +from datetime import datetime +from pathlib import Path +from .sub_utils import vtt_to_txt + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain.pydantic_v1 import BaseModel, Field + + +class VideoInfo(BaseModel): + video_id: str + url: str + relative_video_path: str + subs: str + transcript: str + + +def gen_queries(llm, task: str, system_prompt: str) -> List[str]: + class QueryList(BaseModel): + """A list of queries to find videos on a video hosting service""" + + search_queries: list[str] = Field(default=None, description="a list of queries") + + messages = [ + SystemMessage(content=str(system_prompt)), + HumanMessage(content=task), + ] + + model = llm.with_structured_output(QueryList) + response: QueryList = model.invoke(messages) + + return response.search_queries + + +def get_video_ids(queries: List[str]) -> List[str]: + videos_per_query = 1 + sleep = 0 + sort_by = "relevance" + results_type = "video" + only_creative_commons = False + + video_ids = set() + for query in queries: + for video in scrapetube.get_search( + query=query, + limit=videos_per_query, + sleep=sleep, + sort_by=sort_by, + results_type=results_type, + ): + video_ids.add(video["videoId"]) + video_ids = list(video_ids) + + if only_creative_commons: + video_ids_cc = [] + for i in video_ids: + YDL_OPTIONS = { + "quiet": True, + "simulate": True, + "forceurl": True, + } + with yt_dlp.YoutubeDL(YDL_OPTIONS) as ydl: + info = ydl.extract_info(f"youtube.com/watch?v={i}", download=False) + if "creative commons" in info.get("license", "").lower(): + video_ids_cc.append(i) + video_ids = video_ids_cc + + return video_ids + + +def download(video_ids: List[str]) -> List[VideoInfo]: + + LOCAL_ROOT = Path("./tmp/agent_squats").resolve() + video_dir = LOCAL_ROOT / "videos" + sub_dir = LOCAL_ROOT / "subs" + + discard_path = LOCAL_ROOT / "videos_without_subs" + discard_path.mkdir(parents=True, exist_ok=True) + + downloaded_video_ids = [video_path.stem for video_path in video_dir.glob("*.mp4")] + downloaded_video_ids += [ + video_path.stem for video_path in discard_path.glob("*.mp4") + ] + + print(f"Downloaded video ids: {downloaded_video_ids}") + + only_with_transcripts = True + + YDL_OPTIONS = { + "writeautomaticsub": True, + "subtitleslangs": ["en"], + "subtitlesformat": "vtt", + "overwrites": False, + "format": "mp4", + "outtmpl": { + "default": video_dir.as_posix() + "/%(id)s.%(ext)s", + "subtitle": sub_dir.as_posix() + "/%(id)s.%(ext)s", + }, + } + + video_infos = [] + + with yt_dlp.YoutubeDL(YDL_OPTIONS) as ydl: + for video_id in video_ids: + url = f"https://www.youtube.com/watch?v={video_id}" + + if video_id not in downloaded_video_ids: + try: + ydl.download(url) + except Exception as e: + print(datetime.now(), f"Error at video {video_id}, skipping") + print(datetime.now(), e) + continue + + video_path = Path(ydl.prepare_filename({"id": video_id, "ext": "mp4"})) + sub_path = Path( + ydl.prepare_filename( + {"id": video_id, "ext": "en.vtt"}, dir_type="subtitle" + ) + ) 
+ + if sub_path.exists(): + with sub_path.open("r") as f: + subs = f.read() + transcript = vtt_to_txt(sub_path) + else: + subs = "" + transcript = "" + + video_info = VideoInfo( + video_id=video_id, + url=url, + relative_video_path=video_path.relative_to(LOCAL_ROOT).as_posix(), + subs=subs, + transcript=transcript, + ) + + video_infos.append(video_info) + + if only_with_transcripts: + filtered_video_infos = [] + for video_info in video_infos: + if video_info.transcript: + filtered_video_infos.append(video_info) + else: + video_path = LOCAL_ROOT / video_info.relative_video_path + video_path.rename(discard_path / video_path.name) + video_infos = filtered_video_infos + + return video_infos diff --git a/agent/tools/sub_utils.py b/agent/tools/sub_utils.py new file mode 100644 index 0000000..4a59802 --- /dev/null +++ b/agent/tools/sub_utils.py @@ -0,0 +1,90 @@ + +# https://gist.github.com/glasslion/b2fcad16bc8a9630dbd7a945ab5ebf5e + + +# import sys +import re + +def remove_tags(text): + """ + Remove vtt markup tags + """ + tags = [ + r'</c>', + r'<c(\.color\w+)?>', + r'<\d{2}:\d{2}:\d{2}\.\d{3}>', + + ] + + for pat in tags: + text = re.sub(pat, '', text) + + # extract timestamp, only keep HH:MM:SS + text = re.sub( + r'(\d{2}:\d{2}:\d{2})\.\d{3} --> .* align:start position:0%', + r'\g<1>', + text + ) + + text = re.sub(r'^\s+$', '', text, flags=re.MULTILINE) + return text + +def remove_header(lines): + """ + Remove vtt file header + """ + pos = -1 + for mark in ('##', 'Language: en',): + if mark in lines: + pos = lines.index(mark) + lines = lines[pos+1:] + return lines + + +def merge_duplicates(lines): + """ + Remove duplicated subtitles. Duplicates are always adjacent. + """ + last_timestamp = '' + last_cap = '' + for line in lines: + if line == "": + continue + if re.match(r'^\d{2}:\d{2}:\d{2}$', line): + if line != last_timestamp: + yield line + last_timestamp = line + else: + if line != last_cap: + yield line + last_cap = line + + +def merge_short_lines(lines): + buffer = '' + for line in lines: + if line == "" or re.match(r'^\d{2}:\d{2}$', line): + yield '\n' + line + continue + + if len(line+buffer) < 80: + buffer += ' ' + line + else: + yield buffer.strip() + buffer = line + yield buffer + +def vtt_to_txt(vtt_file_name, as_list=True): + # txt_name = re.sub(r'.vtt$', '.txt', vtt_file_name) + with open(vtt_file_name) as f: + text = f.read() + text = remove_tags(text) + lines = text.splitlines() + lines = remove_header(lines) + lines = merge_duplicates(lines) + lines = list(lines) + # lines = merge_short_lines(lines) + # lines = list(lines) + lines = '\n'.join(lines) + + return lines \ No newline at end of file diff --git a/agent/tools/video_chunking.py b/agent/tools/video_chunking.py new file mode 100644 index 0000000..3d41dac --- /dev/null +++ b/agent/tools/video_chunking.py @@ -0,0 +1,190 @@ +import decord +import time +from pathlib import Path +from collections import defaultdict +import torch +from transformers import AutoModel, AutoProcessor +import pandas as pd +from tsmoothie.smoother import LowessSmoother + +from typing import List + +from langchain.pydantic_v1 import BaseModel, Field + +# decord.bridge.set_bridge("torch") + +from .scraping import VideoInfo + + +DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + + +class SegmentInfo(BaseModel): + start_timestamp: str + end_timestamp: str + fps: float + video_id: str + + +class VideoInferenceDataset(torch.utils.data.IterableDataset): + def __init__(self, video_infos: List[VideoInfo], local_root: Path): + super().__init__() + + self.video_infos = video_infos +
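# One frame is sampled per (rounded) second of video in get_frame_generator below, so the downstream CLIP scores form a roughly 1 fps time series per video. +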
self.local_root = local_root + self.frame_generator = self.get_frame_generator(video_infos, local_root) + + @staticmethod + def get_frame_generator(video_infos, local_root: Path): + + for video_idx, video_info in enumerate(video_infos): + video_path = local_root.joinpath(video_info.relative_video_path) + vr = decord.VideoReader(str(video_path)) + num_frames = len(vr) + fps = vr.get_avg_fps() + frame_indices = range(0, num_frames, round(fps)) + + for frame_idx in frame_indices: + # print(f"Frame idx {frame_idx}") + frame = vr[frame_idx].asnumpy() + yield { + "frame": frame, + "frame_idx": frame_idx, + "video_id": video_idx, + } + + def __next__(self): + return next(self.frame_generator) + + def __iter__(self): + return self + + +def get_segments(data, max_gap=3, min_prob=0.1, min_segment=5): + segments = [] + cur_segment_start = None + not_doing = 0 + for i, p in enumerate(data): + if p >= min_prob and cur_segment_start is None: + cur_segment_start = i + elif cur_segment_start is not None and p < min_prob: + if not_doing >= max_gap: + if i - not_doing - cur_segment_start >= min_segment: + segments.append((cur_segment_start, i - not_doing)) + not_doing = 0 + cur_segment_start = None + else: + not_doing += 1 + elif p >= min_prob: + not_doing = 0 + if ( + cur_segment_start is not None + and (i - not_doing - cur_segment_start) >= min_segment + ): + segments.append((cur_segment_start, i - not_doing)) + + return segments + + +def detect_segments( + video_infos: List[VideoInfo], clip_text_prompts: List[str] +) -> List[SegmentInfo]: + + LOCAL_ROOT = Path("./tmp/agent_squats").resolve() + CLIP_MODEL_ID = "google/siglip-so400m-patch14-384" + + model = AutoModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE) + processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID) + + dataset = VideoInferenceDataset(video_infos, LOCAL_ROOT) + + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=1, + batch_size=12, + pin_memory=True, + # worker_init_fn=worker_init_fn, + ) + dataloader = iter(dataloader) + + smoother = LowessSmoother(smooth_fraction=0.02, iterations=1) + + clip_results_dict = defaultdict(list) + + print("Init model complete") + + batch_counter = 0 + MAX_BATCHES = 50 + + while batch_counter < MAX_BATCHES: + batch_counter += 1 + try: + start_time = time.time() + batch = next(dataloader) + # print(f"Fetch time: {time.time() - start_time:.2f} seconds") + except StopIteration: + break + + inputs = processor( + images=batch["frame"], + text=clip_text_prompts, + return_tensors="pt", + padding=True, + truncation=True, + ) + inputs = {k: v.to(DEVICE) for k, v in inputs.items()} + + outputs = model(**inputs) + + logits = outputs.logits_per_image + probs = torch.nn.functional.sigmoid(logits).detach().cpu().numpy() + + for video_idx, frame_idx, prob in zip( + batch["video_id"], batch["frame_idx"], probs + ): + # print(type(video_id.item()), type(frame_idx.item()), type(prob.item())) + video_id = video_infos[video_idx.item()].video_id + + clip_results_dict["video_id"].append(video_id) + clip_results_dict["frame_idx"].append(frame_idx.item()) + clip_results_dict["probs"].append(prob.item()) + + print("All frames processed") + clip_results = pd.DataFrame(clip_results_dict) + print("Dataframe created") + print(clip_results) + + max_gap_seconds = 1 + fps_sampling = 1 + min_prob = 0.1 + min_segment_seconds = 3 + fps = 25 + + segment_infos = [] + for video_id, video_clip_results in clip_results.groupby("video_id"): + probs = video_clip_results["probs"].values + probs = smoother.smooth(probs).smooth_data[0] + 
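# get_segments keeps runs where the smoothed probability stays >= min_prob, tolerating dips of up to max_gap consecutive samples and dropping runs shorter than min_segment samples (all counted at 1 sample per second). +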
segments_start_end = get_segments( + probs, + max_gap=round(max_gap_seconds * fps_sampling), + min_prob=min_prob, + min_segment=round(min_segment_seconds * fps_sampling), + ) + + print(f"Segments for video {video_id}: {segments_start_end}") + + sec2ts = lambda s: time.strftime( + f"%H:%M:%S.{round((s%1)*1000):03d}", time.gmtime(s) + ) + + for start, end in segments_start_end: + segment_infos.append( + SegmentInfo( + start_timestamp=sec2ts(start), + end_timestamp=sec2ts(end), + fps=fps, + video_id=video_id, + ) + ) + + return segment_infos diff --git a/example_gui.py b/example_gui.py new file mode 100644 index 0000000..8b32c72 --- /dev/null +++ b/example_gui.py @@ -0,0 +1,428 @@ +import warnings +warnings.filterwarnings("ignore", message=".*TqdmWarning.*") +from dotenv import load_dotenv + +_ = load_dotenv() + +from langgraph.graph import StateGraph, END +from typing import TypedDict, Annotated, List +import operator +from langgraph.checkpoint.sqlite import SqliteSaver +from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage, ChatMessage +from langchain_openai import ChatOpenAI +from langchain_core.pydantic_v1 import BaseModel +from tavily import TavilyClient +import os +import sqlite3 + +class AgentState(TypedDict): + task: str + lnode: str + plan: str + draft: str + critique: str + content: List[str] + queries: List[str] + revision_number: int + max_revisions: int + count: Annotated[int, operator.add] + + +class Queries(BaseModel): + queries: List[str] + +class ewriter(): + def __init__(self): + self.model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0) + self.PLAN_PROMPT = ("You are an expert writer tasked with writing a high level outline of a short 3 paragraph essay. " + "Write such an outline for the user provided topic. Give the three main headers of an outline of " + "the essay along with any relevant notes or instructions for the sections. ") + self.WRITER_PROMPT = ("You are an essay assistant tasked with writing excellent 3 paragraph essays. " + "Generate the best essay possible for the user's request and the initial outline. " + "If the user provides critique, respond with a revised version of your previous attempts. " + "Utilize all the information below as needed: \n" + "------\n" + "{content}") + self.RESEARCH_PLAN_PROMPT = ("You are a researcher charged with providing information that can " + "be used when writing the following essay. Generate a list of search " + "queries that will gather " + "any relevant information. Only generate 3 queries max.") + self.REFLECTION_PROMPT = ("You are a teacher grading an 3 paragraph essay submission. " + "Generate critique and recommendations for the user's submission. " + "Provide detailed recommendations, including requests for length, depth, style, etc.") + self.RESEARCH_CRITIQUE_PROMPT = ("You are a researcher charged with providing information that can " + "be used when making any requested revisions (as outlined below). " + "Generate a list of search queries that will gather any relevant information. 
" + "Only generate 2 queries max.") + self.tavily = TavilyClient(api_key=os.environ["TAVILY_API_KEY"]) + builder = StateGraph(AgentState) + builder.add_node("planner", self.plan_node) + builder.add_node("research_plan", self.research_plan_node) + builder.add_node("generate", self.generation_node) + builder.add_node("reflect", self.reflection_node) + builder.add_node("research_critique", self.research_critique_node) + builder.set_entry_point("planner") + builder.add_conditional_edges( + "generate", + self.should_continue, + {END: END, "reflect": "reflect"} + ) + builder.add_edge("planner", "research_plan") + builder.add_edge("research_plan", "generate") + builder.add_edge("reflect", "research_critique") + builder.add_edge("research_critique", "generate") + memory = SqliteSaver(conn=sqlite3.connect(":memory:", check_same_thread=False)) + self.graph = builder.compile( + checkpointer=memory, + interrupt_after=['planner', 'generate', 'reflect', 'research_plan', 'research_critique'] + ) + + + def plan_node(self, state: AgentState): + messages = [ + SystemMessage(content=self.PLAN_PROMPT), + HumanMessage(content=state['task']) + ] + response = self.model.invoke(messages) + return {"plan": response.content, + "lnode": "planner", + "count": 1, + } + def research_plan_node(self, state: AgentState): + queries = self.model.with_structured_output(Queries).invoke([ + SystemMessage(content=self.RESEARCH_PLAN_PROMPT), + HumanMessage(content=state['task']) + ]) + content = state['content'] or [] # add to content + for q in queries.queries: + response = self.tavily.search(query=q, max_results=2) + for r in response['results']: + content.append(r['content']) + return {"content": content, + "queries": queries.queries, + "lnode": "research_plan", + "count": 1, + } + def generation_node(self, state: AgentState): + content = "\n\n".join(state['content'] or []) + user_message = HumanMessage( + content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}") + messages = [ + SystemMessage( + content=self.WRITER_PROMPT.format(content=content) + ), + user_message + ] + response = self.model.invoke(messages) + return { + "draft": response.content, + "revision_number": state.get("revision_number", 1) + 1, + "lnode": "generate", + "count": 1, + } + def reflection_node(self, state: AgentState): + messages = [ + SystemMessage(content=self.REFLECTION_PROMPT), + HumanMessage(content=state['draft']) + ] + response = self.model.invoke(messages) + return {"critique": response.content, + "lnode": "reflect", + "count": 1, + } + def research_critique_node(self, state: AgentState): + queries = self.model.with_structured_output(Queries).invoke([ + SystemMessage(content=self.RESEARCH_CRITIQUE_PROMPT), + HumanMessage(content=state['critique']) + ]) + content = state['content'] or [] + for q in queries.queries: + response = self.tavily.search(query=q, max_results=2) + for r in response['results']: + content.append(r['content']) + return {"content": content, + "lnode": "research_critique", + "count": 1, + } + def should_continue(self, state): + if state["revision_number"] > state["max_revisions"]: + return END + return "reflect" + +import gradio as gr +import time + +class writer_gui( ): + def __init__(self, graph, share=False): + self.graph = graph + self.share = share + self.partial_message = "" + self.response = {} + self.max_iterations = 10 + self.iterations = [] + self.threads = [] + self.thread_id = -1 + self.thread = {"configurable": {"thread_id": str(self.thread_id)}} + #self.sdisps = {} #global + self.demo = 
self.create_interface() + + def run_agent(self, start,topic,stop_after): + #global partial_message, thread_id,thread + #global response, max_iterations, iterations, threads + if start: + self.iterations.append(0) + config = {'task': topic,"max_revisions": 2,"revision_number": 0, + 'lnode': "", 'planner': "no plan", 'draft': "no draft", 'critique': "no critique", + 'content': ["no content",], 'queries': "no queries", 'count':0} + self.thread_id += 1 # new agent, new thread + self.threads.append(self.thread_id) + else: + config = None + self.thread = {"configurable": {"thread_id": str(self.thread_id)}} + while self.iterations[self.thread_id] < self.max_iterations: + self.response = self.graph.invoke(config, self.thread) + self.iterations[self.thread_id] += 1 + self.partial_message += str(self.response) + self.partial_message += f"\n------------------\n\n" + ## fix + lnode,nnode,_,rev,acount = self.get_disp_state() + yield self.partial_message,lnode,nnode,self.thread_id,rev,acount + config = None #need + #print(f"run_agent:{lnode}") + if not nnode: + #print("Hit the end") + return + if lnode in stop_after: + #print(f"stopping due to stop_after {lnode}") + return + else: + #print(f"Not stopping on lnode {lnode}") + pass + return + + def get_disp_state(self,): + current_state = self.graph.get_state(self.thread) + lnode = current_state.values["lnode"] + acount = current_state.values["count"] + rev = current_state.values["revision_number"] + nnode = current_state.next + #print (lnode,nnode,self.thread_id,rev,acount) + return lnode,nnode,self.thread_id,rev,acount + + def get_state(self,key): + current_values = self.graph.get_state(self.thread) + if key in current_values.values: + lnode,nnode,self.thread_id,rev,astep = self.get_disp_state() + new_label = f"last_node: {lnode}, thread_id: {self.thread_id}, rev: {rev}, step: {astep}" + return gr.update(label=new_label, value=current_values.values[key]) + else: + return "" + + def get_content(self,): + current_values = self.graph.get_state(self.thread) + if "content" in current_values.values: + content = current_values.values["content"] + lnode,nnode,thread_id,rev,astep = self.get_disp_state() + new_label = f"last_node: {lnode}, thread_id: {self.thread_id}, rev: {rev}, step: {astep}" + return gr.update(label=new_label, value="\n\n".join(item for item in content) + "\n\n") + else: + return "" + + def update_hist_pd(self,): + #print("update_hist_pd") + hist = [] + # curiously, this generator returns the latest first + for state in self.graph.get_state_history(self.thread): + if state.metadata['step'] < 1: + continue + thread_ts = state.config['configurable']['thread_ts'] + tid = state.config['configurable']['thread_id'] + count = state.values['count'] + lnode = state.values['lnode'] + rev = state.values['revision_number'] + nnode = state.next + st = f"{tid}:{count}:{lnode}:{nnode}:{rev}:{thread_ts}" + hist.append(st) + return gr.Dropdown(label="update_state from: thread:count:last_node:next_node:rev:thread_ts", + choices=hist, value=hist[0],interactive=True) + + def find_config(self,thread_ts): + for state in self.graph.get_state_history(self.thread): + config = state.config + if config['configurable']['thread_ts'] == thread_ts: + return config + return(None) + + def copy_state(self,hist_str): + ''' result of selecting an old state from the step pulldown. Note does not change thread. + This copies an old state to a new current state. 
+ ''' + thread_ts = hist_str.split(":")[-1] + #print(f"copy_state from {thread_ts}") + config = self.find_config(thread_ts) + #print(config) + state = self.graph.get_state(config) + self.graph.update_state(self.thread, state.values, as_node=state.values['lnode']) + new_state = self.graph.get_state(self.thread) #should now match + new_thread_ts = new_state.config['configurable']['thread_ts'] + tid = new_state.config['configurable']['thread_id'] + count = new_state.values['count'] + lnode = new_state.values['lnode'] + rev = new_state.values['revision_number'] + nnode = new_state.next + return lnode,nnode,new_thread_ts,rev,count + + def update_thread_pd(self,): + #print("update_thread_pd") + return gr.Dropdown(label="choose thread", choices=threads, value=self.thread_id,interactive=True) + + def switch_thread(self,new_thread_id): + #print(f"switch_thread{new_thread_id}") + self.thread = {"configurable": {"thread_id": str(new_thread_id)}} + self.thread_id = new_thread_id + return + + def modify_state(self,key,asnode,new_state): + ''' gets the current state, modifes a single value in the state identified by key, and updates state with it. + note that this will create a new 'current state' node. If you do this multiple times with different keys, it will create + one for each update. Note also that it doesn't resume after the update + ''' + current_values = self.graph.get_state(self.thread) + current_values.values[key] = new_state + self.graph.update_state(self.thread, current_values.values,as_node=asnode) + return + + + def create_interface(self): + with gr.Blocks(theme=gr.themes.Default(spacing_size='sm',text_size="sm")) as demo: + + def updt_disp(): + ''' general update display on state change ''' + current_state = self.graph.get_state(self.thread) + hist = [] + # curiously, this generator returns the latest first + for state in self.graph.get_state_history(self.thread): + if state.metadata['step'] < 1: #ignore early states + continue + s_thread_ts = state.config['configurable']['thread_ts'] + s_tid = state.config['configurable']['thread_id'] + s_count = state.values['count'] + s_lnode = state.values['lnode'] + s_rev = state.values['revision_number'] + s_nnode = state.next + st = f"{s_tid}:{s_count}:{s_lnode}:{s_nnode}:{s_rev}:{s_thread_ts}" + hist.append(st) + if not current_state.metadata: #handle init call + return{} + else: + return { + topic_bx : current_state.values["task"], + lnode_bx : current_state.values["lnode"], + count_bx : current_state.values["count"], + revision_bx : current_state.values["revision_number"], + nnode_bx : current_state.next, + threadid_bx : self.thread_id, + thread_pd : gr.Dropdown(label="choose thread", choices=self.threads, value=self.thread_id,interactive=True), + step_pd : gr.Dropdown(label="update_state from: thread:count:last_node:next_node:rev:thread_ts", + choices=hist, value=hist[0],interactive=True), + } + def get_snapshots(): + new_label = f"thread_id: {self.thread_id}, Summary of snapshots" + sstate = "" + for state in self.graph.get_state_history(self.thread): + for key in ['plan', 'draft', 'critique']: + if key in state.values: + state.values[key] = state.values[key][:80] + "..." + if 'content' in state.values: + for i in range(len(state.values['content'])): + state.values['content'][i] = state.values['content'][i][:20] + '...' 
+ if 'writes' in state.metadata: + state.metadata['writes'] = "not shown" + sstate += str(state) + "\n\n" + return gr.update(label=new_label, value=sstate) + + def vary_btn(stat): + #print(f"vary_btn{stat}") + return(gr.update(variant=stat)) + + with gr.Tab("Agent"): + with gr.Row(): + topic_bx = gr.Textbox(label="Essay Topic", value="Pizza Shop") + gen_btn = gr.Button("Generate Essay", scale=0,min_width=80, variant='primary') + cont_btn = gr.Button("Continue Essay", scale=0,min_width=80) + with gr.Row(): + lnode_bx = gr.Textbox(label="last node", min_width=100) + nnode_bx = gr.Textbox(label="next node", min_width=100) + threadid_bx = gr.Textbox(label="Thread", scale=0, min_width=80) + revision_bx = gr.Textbox(label="Draft Rev", scale=0, min_width=80) + count_bx = gr.Textbox(label="count", scale=0, min_width=80) + with gr.Accordion("Manage Agent", open=False): + checks = list(self.graph.nodes.keys()) + checks.remove('__start__') + stop_after = gr.CheckboxGroup(checks,label="Interrupt After State", value=checks, scale=0, min_width=400) + with gr.Row(): + thread_pd = gr.Dropdown(choices=self.threads,interactive=True, label="select thread", min_width=120, scale=0) + step_pd = gr.Dropdown(choices=['N/A'],interactive=True, label="select step", min_width=160, scale=1) + live = gr.Textbox(label="Live Agent Output", lines=5, max_lines=5) + + # actions + sdisps =[topic_bx,lnode_bx,nnode_bx,threadid_bx,revision_bx,count_bx,step_pd,thread_pd] + thread_pd.input(self.switch_thread, [thread_pd], None).then( + fn=updt_disp, inputs=None, outputs=sdisps) + step_pd.input(self.copy_state,[step_pd],None).then( + fn=updt_disp, inputs=None, outputs=sdisps) + gen_btn.click(vary_btn,gr.Number("secondary", visible=False), gen_btn).then( + fn=self.run_agent, inputs=[gr.Number(True, visible=False),topic_bx,stop_after], outputs=[live],show_progress=True).then( + fn=updt_disp, inputs=None, outputs=sdisps).then( + vary_btn,gr.Number("primary", visible=False), gen_btn).then( + vary_btn,gr.Number("primary", visible=False), cont_btn) + cont_btn.click(vary_btn,gr.Number("secondary", visible=False), cont_btn).then( + fn=self.run_agent, inputs=[gr.Number(False, visible=False),topic_bx,stop_after], + outputs=[live]).then( + fn=updt_disp, inputs=None, outputs=sdisps).then( + vary_btn,gr.Number("primary", visible=False), cont_btn) + + with gr.Tab("Plan"): + with gr.Row(): + refresh_btn = gr.Button("Refresh") + modify_btn = gr.Button("Modify") + plan = gr.Textbox(label="Plan", lines=10, interactive=True) + refresh_btn.click(fn=self.get_state, inputs=gr.Number("plan", visible=False), outputs=plan) + modify_btn.click(fn=self.modify_state, inputs=[gr.Number("plan", visible=False), + gr.Number("planner", visible=False), plan],outputs=None).then( + fn=updt_disp, inputs=None, outputs=sdisps) + with gr.Tab("Research Content"): + refresh_btn = gr.Button("Refresh") + content_bx = gr.Textbox(label="content", lines=10) + refresh_btn.click(fn=self.get_content, inputs=None, outputs=content_bx) + with gr.Tab("Draft"): + with gr.Row(): + refresh_btn = gr.Button("Refresh") + modify_btn = gr.Button("Modify") + draft_bx = gr.Textbox(label="draft", lines=10, interactive=True) + refresh_btn.click(fn=self.get_state, inputs=gr.Number("draft", visible=False), outputs=draft_bx) + modify_btn.click(fn=self.modify_state, inputs=[gr.Number("draft", visible=False), + gr.Number("generate", visible=False), draft_bx], outputs=None).then( + fn=updt_disp, inputs=None, outputs=sdisps) + with gr.Tab("Critique"): + with gr.Row(): + refresh_btn = 
gr.Button("Refresh") + modify_btn = gr.Button("Modify") + critique_bx = gr.Textbox(label="Critique", lines=10, interactive=True) + refresh_btn.click(fn=self.get_state, inputs=gr.Number("critique", visible=False), outputs=critique_bx) + modify_btn.click(fn=self.modify_state, inputs=[gr.Number("critique", visible=False), + gr.Number("reflect", visible=False), + critique_bx], outputs=None).then( + fn=updt_disp, inputs=None, outputs=sdisps) + with gr.Tab("StateSnapShots"): + with gr.Row(): + refresh_btn = gr.Button("Refresh") + snapshots = gr.Textbox(label="State Snapshots Summaries") + refresh_btn.click(fn=get_snapshots, inputs=None, outputs=snapshots) + return demo + + def launch(self, share=None): + if port := os.getenv("PORT1"): + self.demo.launch(share=True, server_port=int(port), server_name="0.0.0.0") + else: + self.demo.launch(share=self.share) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2e3aed1..3b41483 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ yt-dlp # openai langchain langchain-openai +langgraph pandas opencv-python scenedetect @@ -17,4 +18,5 @@ tsmoothie torch sentencepiece protobuf -transformers \ No newline at end of file +transformers +decord \ No newline at end of file