Files
dark_hal/hf_downloader.py
2026-03-13 12:56:43 -07:00

586 lines
23 KiB
Python

import json
import os
import threading
import tkinter as tk
from pathlib import Path
from tkinter import ttk, messagebox, filedialog
from typing import Optional, List, Dict, Any
from urllib.parse import quote

import requests
from dotenv import load_dotenv
# Load the HuggingFace API key from HUGGINGFACE.env.
# The current working directory is tried first (the original behaviour).
# The copy next to this script is then tried as a fallback, so launching
# the tool from another directory still finds the key. load_dotenv leaves
# variables that are already set untouched by default (override=False),
# so the first source that defines HF_API_KEY wins.
load_dotenv("HUGGINGFACE.env")
load_dotenv(Path(__file__).resolve().parent / "HUGGINGFACE.env")
class HuggingFaceAPI:
    """Web API-based HuggingFace interface using direct HTTP requests.

    Talks to the Hub endpoints under ``https://huggingface.co`` with the
    ``requests`` library. Network/HTTP failures are reported with ``print``
    and converted into empty / ``False`` results instead of being raised.
    """

    # Maps the public ``repo_type`` argument to the plural URL segment.
    _REPO_TYPE_SEGMENTS = {"model": "models", "dataset": "datasets"}

    def __init__(self, api_key: Optional[str] = None, organization: Optional[str] = None):
        """Create a client.

        Args:
            api_key: Access token. When ``None`` or empty, the ``HF_API_KEY``
                environment variable is used instead.
            organization: Optional organization name, sent as an
                ``X-Organization`` header when non-blank.

        Raises:
            ValueError: If no non-blank API key is available.
        """
        # Strip surrounding whitespace and embedded CR/LF (common in pasted
        # tokens / .env files); they would corrupt the Authorization header.
        raw_key = api_key or os.getenv("HF_API_KEY", "")
        self.api_key = raw_key.strip().replace('\n', '').replace('\r', '')
        if not self.api_key:
            raise ValueError("HuggingFace API key not found. Please set HF_API_KEY in HUGGINGFACE.env")
        self.organization = organization.strip() if organization else None
        self.base_url = "https://huggingface.co"
        self.headers = {"Authorization": f"Bearer {self.api_key}"}
        if self.organization:
            self.headers["X-Organization"] = self.organization

    def _repo_segment(self, repo_type: str) -> str:
        """Return the plural URL segment for *repo_type* ("model" or "dataset").

        Raises:
            ValueError: If *repo_type* is not supported.
        """
        try:
            return self._REPO_TYPE_SEGMENTS[repo_type]
        except KeyError:
            raise ValueError(f"Unsupported repo_type: {repo_type!r}") from None

    def _repo_api_url(self, repo_id: str, repo_type: str = "model") -> str:
        """Return the metadata endpoint URL for a repository."""
        return f"{self.base_url}/api/{self._repo_segment(repo_type)}/{repo_id}"

    def _resolve_url(self, repo_id: str, filename: str, repo_type: str = "model") -> str:
        """Return the download URL of *filename* on the ``main`` revision.

        Model repos live at the site root; other repo types are prefixed with
        their segment (e.g. ``/datasets/<id>/...``). The filename is
        percent-encoded (``/`` kept) so names containing spaces, ``#`` or
        ``?`` are not cut short or misparsed.
        """
        segment = self._repo_segment(repo_type)
        prefix = "" if repo_type == "model" else f"{segment}/"
        return f"{self.base_url}/{prefix}{repo_id}/resolve/main/{quote(filename)}"

    def _search(self, kind: str, query: str, limit: int, sort: str) -> List[Dict[str, Any]]:
        """Shared implementation of the search methods.

        Args:
            kind: Plural endpoint name, ``"models"`` or ``"datasets"``.
            query: Free-text search; omitted from the request when empty.
            limit: Maximum number of results.
            sort: Field to sort by (e.g. ``"downloads"``).

        Returns:
            The decoded JSON response, or ``[]`` on any network/HTTP/JSON error.
        """
        params: Dict[str, Any] = {
            "limit": limit,
            "sort": sort,
            "direction": -1,  # descending (Hub API convention for -1)
            "full": True,
        }
        if query:
            params["search"] = query
        try:
            response = requests.get(f"{self.base_url}/api/{kind}", params=params,
                                    headers=self.headers, timeout=30)
            response.raise_for_status()
            return response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a body that is not valid JSON.
            print(f"Error searching {kind}: {e}")
            return []

    def search_models(self, query: str = "", limit: int = 50, sort: str = "downloads") -> List[Dict[str, Any]]:
        """Search for models using the web API ([] on failure)."""
        return self._search("models", query, limit, sort)

    def search_datasets(self, query: str = "", limit: int = 50, sort: str = "downloads") -> List[Dict[str, Any]]:
        """Search for datasets using the web API ([] on failure)."""
        return self._search("datasets", query, limit, sort)

    def get_model_files(self, repo_id: str, repo_type: str = "model") -> List[Dict[str, Any]]:
        """Get the list of files in a repository.

        Args:
            repo_id: ``owner/name`` identifier.
            repo_type: ``"model"`` (default) or ``"dataset"``.

        Returns:
            The ``siblings`` entries of the repo metadata (dicts with an
            ``rfilename`` key; ``blobs`` is requested so the Hub can include
            per-file ``size``), or ``[]`` on failure.
        """
        url = self._repo_api_url(repo_id, repo_type)
        try:
            response = requests.get(url, params={"blobs": True},
                                    headers=self.headers, timeout=30)
            response.raise_for_status()
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            print(f"Error getting model files: {e}")
            return []
        siblings = data.get("siblings") if isinstance(data, dict) else None
        return siblings or []

    def download_file(self, repo_id: str, filename: str, save_path: str,
                      progress_callback=None, repo_type: str = "model") -> bool:
        """Download one file from a repository to *save_path*.

        Data is streamed into ``<save_path>.part`` and moved into place only
        after the transfer completes. A failed or aborted download therefore
        never leaves a truncated file under the final name.

        Args:
            repo_id: ``owner/name`` identifier.
            filename: Path of the file inside the repo.
            save_path: Destination path; parent directories are created.
            progress_callback: Optional ``f(percent, downloaded, total)``.
                It is called per chunk when the server reports a length.
                Raising from it aborts the download, which then returns
                ``False``.
            repo_type: ``"model"`` (default) or ``"dataset"``.

        Returns:
            ``True`` on success, ``False`` on any failure.
        """
        url = self._resolve_url(repo_id, filename, repo_type)
        tmp_path = f"{save_path}.part"
        try:
            parent = os.path.dirname(save_path)
            if parent:  # os.makedirs("") would raise for a bare filename
                os.makedirs(parent, exist_ok=True)
            # Context manager releases the streamed connection on every path.
            with requests.get(url, headers=self.headers, stream=True, timeout=30) as response:
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                downloaded = 0
                with open(tmp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if not chunk:  # skip keep-alive chunks
                            continue
                        f.write(chunk)
                        downloaded += len(chunk)
                        if progress_callback and total_size > 0:
                            progress = (downloaded / total_size) * 100
                            progress_callback(progress, downloaded, total_size)
            os.replace(tmp_path, save_path)
            return True
        except Exception as e:
            # Broad on purpose: progress_callback may raise anything to cancel.
            print(f"Error downloading file: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the .part file may not exist
            return False
class HuggingFaceDownloaderGUI:
    """GUI for HuggingFace model and dataset search and download.

    Network work runs on daemon threads. Widget and Tk-variable updates are
    posted back to the Tk event loop with ``root.after``, because Tkinter
    objects must only be touched from the thread running ``mainloop``.
    """

    # Glyphs for the file picker's "Select" column (escaped so the source
    # file's encoding cannot silently blank them).
    CHECKED = "\u2611"    # ballot box with check
    UNCHECKED = "\u2610"  # empty ballot box

    def __init__(self, root: tk.Tk):
        self.root = root
        self.root.title("HuggingFace Downloader")
        self.root.geometry("1200x700")
        # Initialize API. The window still opens without a key, so the user
        # sees the error instead of a crash.
        try:
            self.api = HuggingFaceAPI()
        except ValueError as e:
            messagebox.showerror("API Key Error", str(e))
            self.api = None
        # Search variables
        self.search_query = tk.StringVar()
        self.search_type = tk.StringVar(value="Models")
        self.filter_most_downloaded = tk.BooleanVar(value=True)
        self.filter_most_liked = tk.BooleanVar(value=False)
        self.filter_size = tk.BooleanVar(value=False)
        # Raw JSON results of the latest search
        self.current_results = []
        # Result-tree item id -> exact repo id. Treeview "values" would coerce
        # numeric-looking names such as "007" to int, so ids are kept here.
        self._item_repo_ids: Dict[str, str] = {}
        # True while a search worker runs. Enter in the entry bypasses the
        # disabled Search button, so this flag guards against re-entry.
        self._searching = False
        self._build_ui()

    def _build_ui(self):
        """Build the main UI."""
        # Search bar frame
        search_frame = ttk.Frame(self.root)
        search_frame.pack(fill=tk.X, padx=10, pady=10)
        # Search entry
        self.search_entry = ttk.Entry(search_frame, textvariable=self.search_query, width=60)
        self.search_entry.pack(side=tk.LEFT, fill=tk.X, expand=True)
        self.search_entry.bind("<Return>", lambda e: self._perform_search())
        # Search type dropdown
        self.type_dropdown = ttk.Combobox(search_frame, textvariable=self.search_type,
                                          values=["Models", "Datasets"],
                                          state="readonly", width=15)
        self.type_dropdown.pack(side=tk.LEFT, padx=(10, 0))
        # Search button
        self.search_button = ttk.Button(search_frame, text="Search", command=self._perform_search)
        self.search_button.pack(side=tk.LEFT, padx=(10, 0))
        # Results frame with treeview
        results_frame = ttk.Frame(self.root)
        results_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=(0, 10))
        columns = ("creator", "name", "description", "keywords", "size", "metadata")
        self.results_tree = ttk.Treeview(results_frame, columns=columns, show="headings", height=20)
        self.results_tree.heading("creator", text="Creator")
        self.results_tree.heading("name", text="Name")
        self.results_tree.heading("description", text="Description")
        self.results_tree.heading("keywords", text="Keywords")
        self.results_tree.heading("size", text="Size")
        self.results_tree.heading("metadata", text="Metadata")
        self.results_tree.column("creator", width=150)
        self.results_tree.column("name", width=200)
        self.results_tree.column("description", width=300)
        self.results_tree.column("keywords", width=150)
        self.results_tree.column("size", width=100)
        self.results_tree.column("metadata", width=200)
        # Scrollbars
        vsb = ttk.Scrollbar(results_frame, orient="vertical", command=self.results_tree.yview)
        hsb = ttk.Scrollbar(results_frame, orient="horizontal", command=self.results_tree.xview)
        self.results_tree.configure(yscrollcommand=vsb.set, xscrollcommand=hsb.set)
        self.results_tree.grid(row=0, column=0, sticky="nsew")
        vsb.grid(row=0, column=1, sticky="ns")
        hsb.grid(row=1, column=0, sticky="ew")
        results_frame.grid_rowconfigure(0, weight=1)
        results_frame.grid_columnconfigure(0, weight=1)
        # Double-click a row to download it
        self.results_tree.bind("<Double-Button-1>", self._on_item_double_click)
        # Filter footer frame
        filter_frame = ttk.Frame(self.root)
        filter_frame.pack(fill=tk.X, padx=10, pady=(0, 10))
        ttk.Label(filter_frame, text="Filter:").pack(side=tk.LEFT)
        ttk.Checkbutton(filter_frame, text="Most Downloaded",
                        variable=self.filter_most_downloaded,
                        command=self._update_filters).pack(side=tk.LEFT, padx=(10, 0))
        ttk.Checkbutton(filter_frame, text="Most Liked",
                        variable=self.filter_most_liked,
                        command=self._update_filters).pack(side=tk.LEFT, padx=(10, 0))
        ttk.Checkbutton(filter_frame, text="Size",
                        variable=self.filter_size,
                        command=self._update_filters).pack(side=tk.LEFT, padx=(10, 0))
        # Status bar
        self.status_var = tk.StringVar(value="Ready")
        status_bar = ttk.Label(self.root, textvariable=self.status_var, relief=tk.SUNKEN)
        status_bar.pack(fill=tk.X, side=tk.BOTTOM)
        # Download button
        download_frame = ttk.Frame(self.root)
        download_frame.pack(fill=tk.X, padx=10, pady=(0, 10))
        self.download_button = ttk.Button(download_frame, text="Download Selected",
                                          command=self._download_selected)
        self.download_button.pack(side=tk.RIGHT)

    def _update_filters(self):
        """Keep at least one sort checkbox ticked.

        The selection only takes effect on the next search; existing results
        are not re-sorted.
        """
        if not any([self.filter_most_downloaded.get(),
                    self.filter_most_liked.get(),
                    self.filter_size.get()]):
            self.filter_most_downloaded.set(True)

    def _sort_key(self) -> str:
        """Map the filter checkboxes to the API ``sort`` parameter."""
        most_downloaded = self.filter_most_downloaded.get()
        most_liked = self.filter_most_liked.get()
        if most_liked and not most_downloaded:
            return "likes"
        if self.filter_size.get() and not most_downloaded and not most_liked:
            # NOTE(review): "Size" is sent as "lastModified", not a size sort;
            # confirm this is intended.
            return "lastModified"
        return "downloads"

    def _perform_search(self):
        """Start a background search using the current query and filters."""
        if not self.api:
            messagebox.showerror("Error", "API not initialized")
            return
        if self._searching:
            return
        query = self.search_query.get().strip()
        search_type = self.search_type.get()
        sort = self._sort_key()
        self._searching = True
        self.status_var.set(f"Searching {search_type.lower()}...")
        self.search_button.config(state="disabled")
        # Clear previous results
        for item in self.results_tree.get_children():
            self.results_tree.delete(item)
        self._item_repo_ids.clear()
        threading.Thread(target=self._search_thread,
                         args=(query, search_type, sort), daemon=True).start()

    def _search_thread(self, query: str, search_type: str, sort: str):
        """Worker: run the search, then hand the results to the Tk thread."""
        try:
            if search_type == "Models":
                results = self.api.search_models(query, limit=50, sort=sort)
            else:
                results = self.api.search_datasets(query, limit=50, sort=sort)
        except Exception as e:
            # Capture the text now: `e` is unbound once this block exits, so a
            # lambda that referenced it would raise NameError when Tk ran it.
            message = str(e)
            self.root.after(0, lambda m=message: messagebox.showerror("Search Error", m))
            self.root.after(0, self.status_var.set, "Search failed")
        else:
            self.root.after(0, self._populate_results, results, search_type)
        finally:
            self.root.after(0, self._end_search)

    def _end_search(self):
        """Re-enable searching (runs on the Tk thread)."""
        self._searching = False
        self.search_button.config(state="normal")

    def _populate_results(self, results: List[Dict], search_type: str):
        """Populate the treeview with search results (runs on the Tk thread)."""
        self.current_results = results
        count = 0
        for item in results:
            try:
                values, repo_id = self._build_row(item, search_type)
            except Exception as e:
                # One malformed entry must not hide the rest of the results.
                print(f"Error processing result: {e}")
                continue
            item_id = self.results_tree.insert("", tk.END, values=values)
            self._item_repo_ids[item_id] = repo_id
            count += 1
        self.status_var.set(f"Found {count} {search_type.lower()}")

    def _build_row(self, item: Dict[str, Any], search_type: str):
        """Convert one search-result dict into result-tree column values.

        Returns:
            ``(values, repo_id)``. *values* is the tuple
            (creator, name, description, keywords, size, metadata).
        """
        # JSON nulls are normalised with ``or`` so they never reach len()/join.
        if search_type == "Models":
            repo_id = item.get("modelId") or item.get("id") or ""
            tags = item.get("tags") or []
            keywords = (", ".join(str(t) for t in tags[:3]) if tags
                        else (item.get("pipeline_tag") or ""))
            description = item.get("description") or ""
        else:
            repo_id = item.get("id") or ""
            card_data = item.get("cardData") or {}
            task_ids = card_data.get("task_ids") or []
            keywords = ", ".join(str(t) for t in task_ids[:3]) if task_ids else "dataset"
            description = card_data.get("description") or card_data.get("summary") or ""

        creator, sep, name = repo_id.partition("/")
        if not sep:
            creator, name = "", repo_id

        # Collapse newlines/runs of whitespace: treeview cells are one line.
        description = " ".join(str(description).split())
        if len(description) > 100:
            description = description[:97] + "..."

        size_bytes = 0
        for sibling in item.get("siblings") or []:
            if isinstance(sibling, dict):
                size = sibling.get("size", 0)
                if isinstance(size, (int, float)):
                    size_bytes += size
        size_str = self._format_size(size_bytes) if size_bytes > 0 else "-"

        metadata_parts = []
        downloads = item.get("downloads") or 0
        likes = item.get("likes") or 0
        if downloads > 0:
            metadata_parts.append(f"\u2193 {self._format_number(downloads)}")
        if likes > 0:
            metadata_parts.append(f"\u2665 {self._format_number(likes)}")
        if search_type == "Models":
            library = item.get("library_name") or ""
            if library:
                metadata_parts.append(library)
        metadata = " | ".join(metadata_parts)

        return (creator, name, description, keywords, size_str, metadata), repo_id

    def _format_size(self, bytes_size: int) -> str:
        """Format bytes to human readable size."""
        for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
            if bytes_size < 1024.0:
                return f"{bytes_size:.1f} {unit}"
            bytes_size /= 1024.0
        return f"{bytes_size:.1f} PB"

    def _format_number(self, num: int) -> str:
        """Format large numbers with K, M suffixes."""
        if num >= 1_000_000:
            return f"{num/1_000_000:.1f}M"
        elif num >= 1_000:
            return f"{num/1_000:.1f}K"
        return str(num)

    @staticmethod
    def _safe_join(base: str, relative: str) -> Optional[str]:
        """Join *relative* under *base*, or return None if it would escape.

        Filenames come from the remote API. ``..`` segments and absolute
        paths are therefore rejected, so nothing is written outside the
        chosen directory.
        """
        base_abs = os.path.abspath(base)
        candidate = os.path.abspath(os.path.join(base_abs, relative))
        try:
            inside = os.path.commonpath([base_abs, candidate]) == base_abs
        except ValueError:  # e.g. different drives on Windows
            inside = False
        return candidate if inside and candidate != base_abs else None

    def _on_item_double_click(self, event):
        """Download the row that was double-clicked (ignore headings/empty space)."""
        if self.results_tree.identify_region(event.x, event.y) in ("cell", "tree"):
            self._download_selected()

    def _download_selected(self):
        """Download the selected model or dataset."""
        selection = self.results_tree.selection()
        if not selection:
            messagebox.showinfo("No Selection", "Please select an item to download")
            return
        if not self.api:
            messagebox.showerror("Error", "API not initialized")
            return
        repo_id = self._item_repo_ids.get(selection[0])
        if not repo_id:
            return
        download_dir = filedialog.askdirectory(title="Select Download Directory")
        if not download_dir:
            return
        self._show_download_window(repo_id, download_dir)

    def _show_download_window(self, repo_id: str, download_dir: str):
        """Show a window for selecting files to download."""
        download_window = tk.Toplevel(self.root)
        download_window.title(f"Download: {repo_id}")
        download_window.geometry("800x500")
        ttk.Label(download_window, text="Fetching file list...").pack(pady=10)

        def on_files(files):
            # The user may have closed the window while the list was loading.
            if download_window.winfo_exists():
                self._populate_download_window(download_window, repo_id, download_dir, files)

        def fetch_files():
            # NOTE(review): get_model_files is called without a repo type, so
            # dataset results are presumably looked up as models; confirm.
            files = self.api.get_model_files(repo_id)
            self.root.after(0, on_files, files)

        threading.Thread(target=fetch_files, daemon=True).start()

    def _populate_download_window(self, window: tk.Toplevel, repo_id: str,
                                  download_dir: str, files: List[Dict]):
        """Populate the download window with a checkable file list."""
        for widget in window.winfo_children():
            widget.destroy()
        ttk.Label(window, text=f"Select files to download from {repo_id}:").pack(pady=5)
        # File list frame
        list_frame = ttk.Frame(window)
        list_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
        columns = ("filename", "size")
        file_tree = ttk.Treeview(list_frame, columns=columns, show="tree headings", height=15)
        file_tree.heading("#0", text="Select")
        file_tree.heading("filename", text="File")
        file_tree.heading("size", text="Size")
        file_tree.column("#0", width=50)
        file_tree.column("filename", width=500)
        file_tree.column("size", width=100)
        vsb = ttk.Scrollbar(list_frame, orient="vertical", command=file_tree.yview)
        file_tree.configure(yscrollcommand=vsb.set)
        file_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        vsb.pack(side=tk.RIGHT, fill=tk.Y)

        # item id -> {"filename": str, "selected": bool}
        file_vars: Dict[str, Dict[str, Any]] = {}
        for file_info in files:
            filename = file_info.get("rfilename", "")
            if not filename:
                continue  # an empty name would resolve to the repo root
            size = file_info.get("size") or 0
            size_str = self._format_size(size) if size > 0 else "-"
            item_id = file_tree.insert("", tk.END, text=self.UNCHECKED,
                                       values=(filename, size_str))
            file_vars[item_id] = {"filename": filename, "selected": False}

        def set_selected(item_id, selected):
            file_vars[item_id]["selected"] = selected
            file_tree.item(item_id, text=self.CHECKED if selected else self.UNCHECKED)

        def toggle_selection(event):
            item = file_tree.identify("item", event.x, event.y)
            if item in file_vars:
                set_selected(item, not file_vars[item]["selected"])

        file_tree.bind("<Button-1>", toggle_selection)

        button_frame = ttk.Frame(window)
        button_frame.pack(fill=tk.X, padx=10, pady=10)

        def select_all():
            for item_id in file_vars:
                set_selected(item_id, True)

        def select_none():
            for item_id in file_vars:
                set_selected(item_id, False)

        def select_gguf():
            for item_id, info in file_vars.items():
                set_selected(item_id, info["filename"].lower().endswith(".gguf"))

        ttk.Button(button_frame, text="Select All", command=select_all).pack(side=tk.LEFT, padx=5)
        ttk.Button(button_frame, text="Select None", command=select_none).pack(side=tk.LEFT, padx=5)
        ttk.Button(button_frame, text="Select GGUF Only", command=select_gguf).pack(side=tk.LEFT, padx=5)

        def start_download():
            selected_files = [info["filename"] for info in file_vars.values() if info["selected"]]
            if not selected_files:
                messagebox.showinfo("No Selection", "Please select at least one file to download")
                return
            window.destroy()
            self._download_files(repo_id, selected_files, download_dir)

        ttk.Button(button_frame, text="Download Selected",
                   command=start_download).pack(side=tk.RIGHT, padx=5)
        ttk.Button(button_frame, text="Cancel",
                   command=window.destroy).pack(side=tk.RIGHT, padx=5)

    def _download_files(self, repo_id: str, files: List[str], download_dir: str):
        """Download *files* of *repo_id* into ``download_dir/<owner>_<name>/``.

        A progress window is shown. Cancel, or closing that window, aborts
        the file currently in flight and skips the remaining files.
        """
        progress_window = tk.Toplevel(self.root)
        progress_window.title("Downloading...")
        progress_window.geometry("500x200")
        ttk.Label(progress_window, text=f"Downloading from {repo_id}").pack(pady=10)
        current_file_var = tk.StringVar(value="Preparing...")
        ttk.Label(progress_window, textvariable=current_file_var).pack(pady=5)
        progress_var = tk.DoubleVar()
        progress_bar = ttk.Progressbar(progress_window, variable=progress_var,
                                       maximum=100, length=400)
        progress_bar.pack(pady=10)
        status_var = tk.StringVar(value="Starting download...")
        ttk.Label(progress_window, textvariable=status_var).pack(pady=5)

        cancel_event = threading.Event()

        def cancel_download():
            cancel_event.set()
            progress_window.destroy()

        ttk.Button(progress_window, text="Cancel", command=cancel_download).pack(pady=10)
        progress_window.protocol("WM_DELETE_WINDOW", cancel_download)

        def post(callback, *args):
            """Run callback(*args) on the Tk thread unless the window is gone."""
            def run():
                if progress_window.winfo_exists():
                    callback(*args)
            self.root.after(0, run)

        target_dir = os.path.join(download_dir, repo_id.replace("/", "_"))

        def download_thread():
            total_files = len(files)
            completed = 0
            for index, filename in enumerate(files, start=1):
                if cancel_event.is_set():
                    break
                save_path = self._safe_join(target_dir, filename)
                if save_path is None:
                    print(f"Skipping file outside download directory: {filename}")
                    continue
                post(current_file_var.set, f"Downloading: {filename}")
                post(progress_var.set, 0)
                last_percent = [-1]

                def update_progress(percent, downloaded, total,
                                    index=index, last_percent=last_percent):
                    if cancel_event.is_set():
                        # download_file catches this and returns False.
                        raise RuntimeError("Download cancelled")
                    # Throttle: post only when the whole percent changes, so
                    # per-chunk callbacks do not flood the Tk event queue.
                    whole = int(percent)
                    if whole == last_percent[0] and downloaded < total:
                        return
                    last_percent[0] = whole
                    size_str = f"{self._format_size(downloaded)} / {self._format_size(total)}"
                    post(progress_var.set, percent)
                    post(status_var.set, f"File {index}/{total_files}: {size_str}")

                if self.api.download_file(repo_id, filename, save_path, update_progress):
                    completed += 1

            if not cancel_event.is_set():
                summary = f"Downloaded {completed}/{total_files} files to {download_dir}"

                def finish():
                    progress_window.destroy()
                    messagebox.showinfo("Download Complete", summary)

                post(finish)

        threading.Thread(target=download_thread, daemon=True).start()
def main():
    """Open the HuggingFace Downloader window and block in the Tk event loop."""
    window = tk.Tk()
    # Keep the GUI object referenced for as long as the event loop runs.
    gui = HuggingFaceDownloaderGUI(window)  # noqa: F841
    window.mainloop()


if __name__ == "__main__":
    # Launch the GUI only when executed as a script, never on import.
    main()