Files
dark_hal/finetune_tab.py

337 lines
16 KiB
Python
Raw Normal View History

2026-03-13 12:56:43 -07:00
#!/usr/bin/env python3
"""
Fine Tune Tab for DarkHal 2.0
Model fine-tuning interface for training and customizing AI models.
"""
import tkinter as tk
from tkinter import ttk, messagebox, scrolledtext, filedialog
import os
import sys
import json
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, Any, List
class FineTuneTab:
"""Fine-tuning interface for training and customizing AI models."""
def __init__(self, parent: ttk.Frame, settings_manager):
self.parent = parent
self.settings = settings_manager
self.current_model = None
self.training_in_progress = False
# Create main frame
self.main_frame = ttk.Frame(parent)
self.main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
# Create fine-tuning interface
self._create_finetune_interface()
def _create_finetune_interface(self):
"""Create the main fine-tuning interface."""
# Model Selection Frame
model_frame = ttk.LabelFrame(self.main_frame, text="Model Selection", padding=10)
model_frame.pack(fill=tk.X, pady=(0, 10))
# Base model selection
ttk.Label(model_frame, text="Base Model:").grid(row=0, column=0, sticky=tk.W, pady=5)
self.base_model_var = tk.StringVar()
self.base_model_entry = ttk.Entry(model_frame, textvariable=self.base_model_var, width=50)
self.base_model_entry.grid(row=0, column=1, padx=5)
ttk.Button(model_frame, text="Browse", command=self._browse_base_model).grid(row=0, column=2, padx=5)
# Output model name
ttk.Label(model_frame, text="Output Model Name:").grid(row=1, column=0, sticky=tk.W, pady=5)
self.output_model_var = tk.StringVar(value="my-finetuned-model")
ttk.Entry(model_frame, textvariable=self.output_model_var, width=50).grid(row=1, column=1, padx=5)
# Training Data Frame
data_frame = ttk.LabelFrame(self.main_frame, text="Training Data", padding=10)
data_frame.pack(fill=tk.X, pady=(0, 10))
# Dataset file
ttk.Label(data_frame, text="Dataset File:").grid(row=0, column=0, sticky=tk.W, pady=5)
self.dataset_var = tk.StringVar()
ttk.Entry(data_frame, textvariable=self.dataset_var, width=50).grid(row=0, column=1, padx=5)
ttk.Button(data_frame, text="Browse", command=self._browse_dataset).grid(row=0, column=2, padx=5)
# Dataset format
ttk.Label(data_frame, text="Dataset Format:").grid(row=1, column=0, sticky=tk.W, pady=5)
self.format_var = tk.StringVar(value="alpaca")
format_combo = ttk.Combobox(data_frame, textvariable=self.format_var,
values=["alpaca", "sharegpt", "completion", "chat", "custom"],
state="readonly", width=20)
format_combo.grid(row=1, column=1, sticky=tk.W, padx=5)
# Training split
ttk.Label(data_frame, text="Train/Val Split:").grid(row=2, column=0, sticky=tk.W, pady=5)
self.split_var = tk.StringVar(value="90/10")
ttk.Combobox(data_frame, textvariable=self.split_var,
values=["80/20", "90/10", "95/5", "100/0"],
state="readonly", width=20).grid(row=2, column=1, sticky=tk.W, padx=5)
# Training Parameters Frame
params_frame = ttk.LabelFrame(self.main_frame, text="Training Parameters", padding=10)
params_frame.pack(fill=tk.X, pady=(0, 10))
# Create two columns for parameters
left_params = ttk.Frame(params_frame)
left_params.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
right_params = ttk.Frame(params_frame)
right_params.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
# Left column parameters
ttk.Label(left_params, text="Training Method:").grid(row=0, column=0, sticky=tk.W, pady=5)
self.method_var = tk.StringVar(value="LoRA")
ttk.Combobox(left_params, textvariable=self.method_var,
values=["LoRA", "QLoRA", "Full Fine-tune", "PEFT"],
state="readonly", width=15).grid(row=0, column=1, padx=5)
ttk.Label(left_params, text="Epochs:").grid(row=1, column=0, sticky=tk.W, pady=5)
self.epochs_var = tk.StringVar(value="3")
ttk.Spinbox(left_params, from_=1, to=100, textvariable=self.epochs_var, width=15).grid(row=1, column=1, padx=5)
ttk.Label(left_params, text="Batch Size:").grid(row=2, column=0, sticky=tk.W, pady=5)
self.batch_var = tk.StringVar(value="4")
ttk.Spinbox(left_params, from_=1, to=64, textvariable=self.batch_var, width=15).grid(row=2, column=1, padx=5)
ttk.Label(left_params, text="Learning Rate:").grid(row=3, column=0, sticky=tk.W, pady=5)
self.lr_var = tk.StringVar(value="2e-4")
ttk.Entry(left_params, textvariable=self.lr_var, width=15).grid(row=3, column=1, padx=5)
# Right column parameters
ttk.Label(right_params, text="LoRA Rank:").grid(row=0, column=0, sticky=tk.W, pady=5)
self.lora_rank_var = tk.StringVar(value="8")
ttk.Spinbox(right_params, from_=4, to=128, textvariable=self.lora_rank_var, width=15).grid(row=0, column=1, padx=5)
ttk.Label(right_params, text="LoRA Alpha:").grid(row=1, column=0, sticky=tk.W, pady=5)
self.lora_alpha_var = tk.StringVar(value="16")
ttk.Spinbox(right_params, from_=8, to=256, textvariable=self.lora_alpha_var, width=15).grid(row=1, column=1, padx=5)
ttk.Label(right_params, text="Max Length:").grid(row=2, column=0, sticky=tk.W, pady=5)
self.max_length_var = tk.StringVar(value="512")
ttk.Spinbox(right_params, from_=128, to=4096, increment=128, textvariable=self.max_length_var, width=15).grid(row=2, column=1, padx=5)
ttk.Label(right_params, text="Warmup Steps:").grid(row=3, column=0, sticky=tk.W, pady=5)
self.warmup_var = tk.StringVar(value="100")
ttk.Spinbox(right_params, from_=0, to=1000, textvariable=self.warmup_var, width=15).grid(row=3, column=1, padx=5)
# Hardware Settings Frame
hardware_frame = ttk.LabelFrame(self.main_frame, text="Hardware Settings", padding=10)
hardware_frame.pack(fill=tk.X, pady=(0, 10))
ttk.Label(hardware_frame, text="Device:").grid(row=0, column=0, sticky=tk.W, pady=5)
self.device_var = tk.StringVar(value="cuda")
ttk.Combobox(hardware_frame, textvariable=self.device_var,
values=["cuda", "cpu", "mps"],
state="readonly", width=15).grid(row=0, column=1, padx=5)
ttk.Label(hardware_frame, text="Mixed Precision:").grid(row=0, column=2, sticky=tk.W, padx=(20, 5))
self.mixed_precision_var = tk.BooleanVar(value=True)
ttk.Checkbutton(hardware_frame, variable=self.mixed_precision_var).grid(row=0, column=3)
ttk.Label(hardware_frame, text="Gradient Checkpointing:").grid(row=0, column=4, sticky=tk.W, padx=(20, 5))
self.grad_checkpoint_var = tk.BooleanVar(value=True)
ttk.Checkbutton(hardware_frame, variable=self.grad_checkpoint_var).grid(row=0, column=5)
# Control Buttons Frame
control_frame = ttk.Frame(self.main_frame)
control_frame.pack(fill=tk.X, pady=(0, 10))
ttk.Button(control_frame, text="Start Training", command=self._start_training,
style="Accent.TButton").pack(side=tk.LEFT, padx=5)
ttk.Button(control_frame, text="Stop Training", command=self._stop_training).pack(side=tk.LEFT, padx=5)
ttk.Button(control_frame, text="Save Config", command=self._save_config).pack(side=tk.LEFT, padx=5)
ttk.Button(control_frame, text="Load Config", command=self._load_config).pack(side=tk.LEFT, padx=5)
ttk.Button(control_frame, text="Validate Dataset", command=self._validate_dataset).pack(side=tk.LEFT, padx=5)
# Progress Frame
progress_frame = ttk.LabelFrame(self.main_frame, text="Training Progress", padding=10)
progress_frame.pack(fill=tk.BOTH, expand=True)
# Progress bar
self.progress_var = tk.DoubleVar()
self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_var, maximum=100)
self.progress_bar.pack(fill=tk.X, pady=(0, 10))
# Status label
self.status_label = ttk.Label(progress_frame, text="Ready to start training", foreground="green")
self.status_label.pack(anchor=tk.W, pady=(0, 10))
# Training log
ttk.Label(progress_frame, text="Training Log:", font=("Arial", 10, "bold")).pack(anchor=tk.W)
self.log_text = scrolledtext.ScrolledText(progress_frame, height=10, width=80, state=tk.DISABLED)
self.log_text.pack(fill=tk.BOTH, expand=True)
def _browse_base_model(self):
"""Browse for base model file."""
filename = filedialog.askopenfilename(
title="Select Base Model",
filetypes=[("GGUF files", "*.gguf"), ("All files", "*.*")]
)
if filename:
self.base_model_var.set(filename)
self._log(f"Selected base model: {filename}")
def _browse_dataset(self):
"""Browse for dataset file."""
filename = filedialog.askopenfilename(
title="Select Dataset",
filetypes=[
("JSON files", "*.json"),
("JSONL files", "*.jsonl"),
("CSV files", "*.csv"),
("Text files", "*.txt"),
("All files", "*.*")
]
)
if filename:
self.dataset_var.set(filename)
self._log(f"Selected dataset: {filename}")
def _start_training(self):
"""Start the fine-tuning process."""
if self.training_in_progress:
messagebox.showwarning("Training in Progress", "Training is already in progress!")
return
# Validate inputs
if not self.base_model_var.get():
messagebox.showerror("Error", "Please select a base model")
return
if not self.dataset_var.get():
messagebox.showerror("Error", "Please select a dataset")
return
self.training_in_progress = True
self.status_label.config(text="Training in progress...", foreground="orange")
self._log("=" * 50)
self._log("Starting fine-tuning process...")
self._log(f"Base Model: {self.base_model_var.get()}")
self._log(f"Dataset: {self.dataset_var.get()}")
self._log(f"Method: {self.method_var.get()}")
self._log(f"Epochs: {self.epochs_var.get()}")
self._log("=" * 50)
# TODO: Implement actual training logic
messagebox.showinfo("Coming Soon", "Fine-tuning functionality will be implemented soon!")
self.training_in_progress = False
self.status_label.config(text="Training complete (stub)", foreground="green")
def _stop_training(self):
"""Stop the training process."""
if not self.training_in_progress:
messagebox.showinfo("No Training", "No training in progress")
return
self.training_in_progress = False
self.status_label.config(text="Training stopped", foreground="red")
self._log("Training stopped by user")
def _save_config(self):
"""Save training configuration to file."""
config = {
"base_model": self.base_model_var.get(),
"output_model": self.output_model_var.get(),
"dataset": self.dataset_var.get(),
"format": self.format_var.get(),
"split": self.split_var.get(),
"method": self.method_var.get(),
"epochs": self.epochs_var.get(),
"batch_size": self.batch_var.get(),
"learning_rate": self.lr_var.get(),
"lora_rank": self.lora_rank_var.get(),
"lora_alpha": self.lora_alpha_var.get(),
"max_length": self.max_length_var.get(),
"warmup_steps": self.warmup_var.get(),
"device": self.device_var.get(),
"mixed_precision": self.mixed_precision_var.get(),
"gradient_checkpointing": self.grad_checkpoint_var.get()
}
filename = filedialog.asksaveasfilename(
title="Save Training Config",
defaultextension=".json",
filetypes=[("JSON files", "*.json"), ("All files", "*.*")]
)
if filename:
with open(filename, 'w') as f:
json.dump(config, f, indent=2)
self._log(f"Config saved to {filename}")
messagebox.showinfo("Saved", f"Configuration saved to {filename}")
def _load_config(self):
"""Load training configuration from file."""
filename = filedialog.askopenfilename(
title="Load Training Config",
filetypes=[("JSON files", "*.json"), ("All files", "*.*")]
)
if filename:
try:
with open(filename, 'r') as f:
config = json.load(f)
# Load values from config
self.base_model_var.set(config.get("base_model", ""))
self.output_model_var.set(config.get("output_model", "my-finetuned-model"))
self.dataset_var.set(config.get("dataset", ""))
self.format_var.set(config.get("format", "alpaca"))
self.split_var.set(config.get("split", "90/10"))
self.method_var.set(config.get("method", "LoRA"))
self.epochs_var.set(config.get("epochs", "3"))
self.batch_var.set(config.get("batch_size", "4"))
self.lr_var.set(config.get("learning_rate", "2e-4"))
self.lora_rank_var.set(config.get("lora_rank", "8"))
self.lora_alpha_var.set(config.get("lora_alpha", "16"))
self.max_length_var.set(config.get("max_length", "512"))
self.warmup_var.set(config.get("warmup_steps", "100"))
self.device_var.set(config.get("device", "cuda"))
self.mixed_precision_var.set(config.get("mixed_precision", True))
self.grad_checkpoint_var.set(config.get("gradient_checkpointing", True))
self._log(f"Config loaded from {filename}")
messagebox.showinfo("Loaded", f"Configuration loaded from {filename}")
except Exception as e:
messagebox.showerror("Error", f"Failed to load config: {e}")
def _validate_dataset(self):
"""Validate the selected dataset."""
if not self.dataset_var.get():
messagebox.showerror("Error", "Please select a dataset first")
return
dataset_path = self.dataset_var.get()
if not os.path.exists(dataset_path):
messagebox.showerror("Error", f"Dataset file not found: {dataset_path}")
return
# TODO: Implement actual dataset validation
self._log(f"Validating dataset: {dataset_path}")
self._log("Dataset validation (stub) - would check format, size, etc.")
messagebox.showinfo("Validation", "Dataset validation complete (stub)")
def _log(self, message: str):
"""Add message to training log."""
self.log_text.config(state=tk.NORMAL)
timestamp = datetime.now().strftime("%H:%M:%S")
self.log_text.insert(tk.END, f"[{timestamp}] {message}\n")
self.log_text.see(tk.END)
self.log_text.config(state=tk.DISABLED)
def set_model(self, model_path: Optional[str]):
"""Set the current model for fine-tuning."""
self.current_model = model_path
if model_path:
self.base_model_var.set(model_path)
self._log(f"Model selected: {Path(model_path).name}")