first commit
This commit is contained in:
337
finetune_tab.py
Normal file
337
finetune_tab.py
Normal file
@@ -0,0 +1,337 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fine Tune Tab for DarkHal 2.0
|
||||
|
||||
Model fine-tuning interface for training and customizing AI models.
|
||||
"""
|
||||
|
||||
import tkinter as tk
|
||||
from tkinter import ttk, messagebox, scrolledtext, filedialog
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
|
||||
class FineTuneTab:
|
||||
"""Fine-tuning interface for training and customizing AI models."""
|
||||
|
||||
def __init__(self, parent: ttk.Frame, settings_manager):
|
||||
self.parent = parent
|
||||
self.settings = settings_manager
|
||||
self.current_model = None
|
||||
self.training_in_progress = False
|
||||
|
||||
# Create main frame
|
||||
self.main_frame = ttk.Frame(parent)
|
||||
self.main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
||||
|
||||
# Create fine-tuning interface
|
||||
self._create_finetune_interface()
|
||||
|
||||
def _create_finetune_interface(self):
|
||||
"""Create the main fine-tuning interface."""
|
||||
|
||||
# Model Selection Frame
|
||||
model_frame = ttk.LabelFrame(self.main_frame, text="Model Selection", padding=10)
|
||||
model_frame.pack(fill=tk.X, pady=(0, 10))
|
||||
|
||||
# Base model selection
|
||||
ttk.Label(model_frame, text="Base Model:").grid(row=0, column=0, sticky=tk.W, pady=5)
|
||||
self.base_model_var = tk.StringVar()
|
||||
self.base_model_entry = ttk.Entry(model_frame, textvariable=self.base_model_var, width=50)
|
||||
self.base_model_entry.grid(row=0, column=1, padx=5)
|
||||
ttk.Button(model_frame, text="Browse", command=self._browse_base_model).grid(row=0, column=2, padx=5)
|
||||
|
||||
# Output model name
|
||||
ttk.Label(model_frame, text="Output Model Name:").grid(row=1, column=0, sticky=tk.W, pady=5)
|
||||
self.output_model_var = tk.StringVar(value="my-finetuned-model")
|
||||
ttk.Entry(model_frame, textvariable=self.output_model_var, width=50).grid(row=1, column=1, padx=5)
|
||||
|
||||
# Training Data Frame
|
||||
data_frame = ttk.LabelFrame(self.main_frame, text="Training Data", padding=10)
|
||||
data_frame.pack(fill=tk.X, pady=(0, 10))
|
||||
|
||||
# Dataset file
|
||||
ttk.Label(data_frame, text="Dataset File:").grid(row=0, column=0, sticky=tk.W, pady=5)
|
||||
self.dataset_var = tk.StringVar()
|
||||
ttk.Entry(data_frame, textvariable=self.dataset_var, width=50).grid(row=0, column=1, padx=5)
|
||||
ttk.Button(data_frame, text="Browse", command=self._browse_dataset).grid(row=0, column=2, padx=5)
|
||||
|
||||
# Dataset format
|
||||
ttk.Label(data_frame, text="Dataset Format:").grid(row=1, column=0, sticky=tk.W, pady=5)
|
||||
self.format_var = tk.StringVar(value="alpaca")
|
||||
format_combo = ttk.Combobox(data_frame, textvariable=self.format_var,
|
||||
values=["alpaca", "sharegpt", "completion", "chat", "custom"],
|
||||
state="readonly", width=20)
|
||||
format_combo.grid(row=1, column=1, sticky=tk.W, padx=5)
|
||||
|
||||
# Training split
|
||||
ttk.Label(data_frame, text="Train/Val Split:").grid(row=2, column=0, sticky=tk.W, pady=5)
|
||||
self.split_var = tk.StringVar(value="90/10")
|
||||
ttk.Combobox(data_frame, textvariable=self.split_var,
|
||||
values=["80/20", "90/10", "95/5", "100/0"],
|
||||
state="readonly", width=20).grid(row=2, column=1, sticky=tk.W, padx=5)
|
||||
|
||||
# Training Parameters Frame
|
||||
params_frame = ttk.LabelFrame(self.main_frame, text="Training Parameters", padding=10)
|
||||
params_frame.pack(fill=tk.X, pady=(0, 10))
|
||||
|
||||
# Create two columns for parameters
|
||||
left_params = ttk.Frame(params_frame)
|
||||
left_params.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
|
||||
|
||||
right_params = ttk.Frame(params_frame)
|
||||
right_params.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
|
||||
|
||||
# Left column parameters
|
||||
ttk.Label(left_params, text="Training Method:").grid(row=0, column=0, sticky=tk.W, pady=5)
|
||||
self.method_var = tk.StringVar(value="LoRA")
|
||||
ttk.Combobox(left_params, textvariable=self.method_var,
|
||||
values=["LoRA", "QLoRA", "Full Fine-tune", "PEFT"],
|
||||
state="readonly", width=15).grid(row=0, column=1, padx=5)
|
||||
|
||||
ttk.Label(left_params, text="Epochs:").grid(row=1, column=0, sticky=tk.W, pady=5)
|
||||
self.epochs_var = tk.StringVar(value="3")
|
||||
ttk.Spinbox(left_params, from_=1, to=100, textvariable=self.epochs_var, width=15).grid(row=1, column=1, padx=5)
|
||||
|
||||
ttk.Label(left_params, text="Batch Size:").grid(row=2, column=0, sticky=tk.W, pady=5)
|
||||
self.batch_var = tk.StringVar(value="4")
|
||||
ttk.Spinbox(left_params, from_=1, to=64, textvariable=self.batch_var, width=15).grid(row=2, column=1, padx=5)
|
||||
|
||||
ttk.Label(left_params, text="Learning Rate:").grid(row=3, column=0, sticky=tk.W, pady=5)
|
||||
self.lr_var = tk.StringVar(value="2e-4")
|
||||
ttk.Entry(left_params, textvariable=self.lr_var, width=15).grid(row=3, column=1, padx=5)
|
||||
|
||||
# Right column parameters
|
||||
ttk.Label(right_params, text="LoRA Rank:").grid(row=0, column=0, sticky=tk.W, pady=5)
|
||||
self.lora_rank_var = tk.StringVar(value="8")
|
||||
ttk.Spinbox(right_params, from_=4, to=128, textvariable=self.lora_rank_var, width=15).grid(row=0, column=1, padx=5)
|
||||
|
||||
ttk.Label(right_params, text="LoRA Alpha:").grid(row=1, column=0, sticky=tk.W, pady=5)
|
||||
self.lora_alpha_var = tk.StringVar(value="16")
|
||||
ttk.Spinbox(right_params, from_=8, to=256, textvariable=self.lora_alpha_var, width=15).grid(row=1, column=1, padx=5)
|
||||
|
||||
ttk.Label(right_params, text="Max Length:").grid(row=2, column=0, sticky=tk.W, pady=5)
|
||||
self.max_length_var = tk.StringVar(value="512")
|
||||
ttk.Spinbox(right_params, from_=128, to=4096, increment=128, textvariable=self.max_length_var, width=15).grid(row=2, column=1, padx=5)
|
||||
|
||||
ttk.Label(right_params, text="Warmup Steps:").grid(row=3, column=0, sticky=tk.W, pady=5)
|
||||
self.warmup_var = tk.StringVar(value="100")
|
||||
ttk.Spinbox(right_params, from_=0, to=1000, textvariable=self.warmup_var, width=15).grid(row=3, column=1, padx=5)
|
||||
|
||||
# Hardware Settings Frame
|
||||
hardware_frame = ttk.LabelFrame(self.main_frame, text="Hardware Settings", padding=10)
|
||||
hardware_frame.pack(fill=tk.X, pady=(0, 10))
|
||||
|
||||
ttk.Label(hardware_frame, text="Device:").grid(row=0, column=0, sticky=tk.W, pady=5)
|
||||
self.device_var = tk.StringVar(value="cuda")
|
||||
ttk.Combobox(hardware_frame, textvariable=self.device_var,
|
||||
values=["cuda", "cpu", "mps"],
|
||||
state="readonly", width=15).grid(row=0, column=1, padx=5)
|
||||
|
||||
ttk.Label(hardware_frame, text="Mixed Precision:").grid(row=0, column=2, sticky=tk.W, padx=(20, 5))
|
||||
self.mixed_precision_var = tk.BooleanVar(value=True)
|
||||
ttk.Checkbutton(hardware_frame, variable=self.mixed_precision_var).grid(row=0, column=3)
|
||||
|
||||
ttk.Label(hardware_frame, text="Gradient Checkpointing:").grid(row=0, column=4, sticky=tk.W, padx=(20, 5))
|
||||
self.grad_checkpoint_var = tk.BooleanVar(value=True)
|
||||
ttk.Checkbutton(hardware_frame, variable=self.grad_checkpoint_var).grid(row=0, column=5)
|
||||
|
||||
# Control Buttons Frame
|
||||
control_frame = ttk.Frame(self.main_frame)
|
||||
control_frame.pack(fill=tk.X, pady=(0, 10))
|
||||
|
||||
ttk.Button(control_frame, text="Start Training", command=self._start_training,
|
||||
style="Accent.TButton").pack(side=tk.LEFT, padx=5)
|
||||
ttk.Button(control_frame, text="Stop Training", command=self._stop_training).pack(side=tk.LEFT, padx=5)
|
||||
ttk.Button(control_frame, text="Save Config", command=self._save_config).pack(side=tk.LEFT, padx=5)
|
||||
ttk.Button(control_frame, text="Load Config", command=self._load_config).pack(side=tk.LEFT, padx=5)
|
||||
ttk.Button(control_frame, text="Validate Dataset", command=self._validate_dataset).pack(side=tk.LEFT, padx=5)
|
||||
|
||||
# Progress Frame
|
||||
progress_frame = ttk.LabelFrame(self.main_frame, text="Training Progress", padding=10)
|
||||
progress_frame.pack(fill=tk.BOTH, expand=True)
|
||||
|
||||
# Progress bar
|
||||
self.progress_var = tk.DoubleVar()
|
||||
self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_var, maximum=100)
|
||||
self.progress_bar.pack(fill=tk.X, pady=(0, 10))
|
||||
|
||||
# Status label
|
||||
self.status_label = ttk.Label(progress_frame, text="Ready to start training", foreground="green")
|
||||
self.status_label.pack(anchor=tk.W, pady=(0, 10))
|
||||
|
||||
# Training log
|
||||
ttk.Label(progress_frame, text="Training Log:", font=("Arial", 10, "bold")).pack(anchor=tk.W)
|
||||
self.log_text = scrolledtext.ScrolledText(progress_frame, height=10, width=80, state=tk.DISABLED)
|
||||
self.log_text.pack(fill=tk.BOTH, expand=True)
|
||||
|
||||
def _browse_base_model(self):
|
||||
"""Browse for base model file."""
|
||||
filename = filedialog.askopenfilename(
|
||||
title="Select Base Model",
|
||||
filetypes=[("GGUF files", "*.gguf"), ("All files", "*.*")]
|
||||
)
|
||||
if filename:
|
||||
self.base_model_var.set(filename)
|
||||
self._log(f"Selected base model: {filename}")
|
||||
|
||||
def _browse_dataset(self):
|
||||
"""Browse for dataset file."""
|
||||
filename = filedialog.askopenfilename(
|
||||
title="Select Dataset",
|
||||
filetypes=[
|
||||
("JSON files", "*.json"),
|
||||
("JSONL files", "*.jsonl"),
|
||||
("CSV files", "*.csv"),
|
||||
("Text files", "*.txt"),
|
||||
("All files", "*.*")
|
||||
]
|
||||
)
|
||||
if filename:
|
||||
self.dataset_var.set(filename)
|
||||
self._log(f"Selected dataset: {filename}")
|
||||
|
||||
def _start_training(self):
|
||||
"""Start the fine-tuning process."""
|
||||
if self.training_in_progress:
|
||||
messagebox.showwarning("Training in Progress", "Training is already in progress!")
|
||||
return
|
||||
|
||||
# Validate inputs
|
||||
if not self.base_model_var.get():
|
||||
messagebox.showerror("Error", "Please select a base model")
|
||||
return
|
||||
|
||||
if not self.dataset_var.get():
|
||||
messagebox.showerror("Error", "Please select a dataset")
|
||||
return
|
||||
|
||||
self.training_in_progress = True
|
||||
self.status_label.config(text="Training in progress...", foreground="orange")
|
||||
self._log("=" * 50)
|
||||
self._log("Starting fine-tuning process...")
|
||||
self._log(f"Base Model: {self.base_model_var.get()}")
|
||||
self._log(f"Dataset: {self.dataset_var.get()}")
|
||||
self._log(f"Method: {self.method_var.get()}")
|
||||
self._log(f"Epochs: {self.epochs_var.get()}")
|
||||
self._log("=" * 50)
|
||||
|
||||
# TODO: Implement actual training logic
|
||||
messagebox.showinfo("Coming Soon", "Fine-tuning functionality will be implemented soon!")
|
||||
|
||||
self.training_in_progress = False
|
||||
self.status_label.config(text="Training complete (stub)", foreground="green")
|
||||
|
||||
def _stop_training(self):
|
||||
"""Stop the training process."""
|
||||
if not self.training_in_progress:
|
||||
messagebox.showinfo("No Training", "No training in progress")
|
||||
return
|
||||
|
||||
self.training_in_progress = False
|
||||
self.status_label.config(text="Training stopped", foreground="red")
|
||||
self._log("Training stopped by user")
|
||||
|
||||
def _save_config(self):
|
||||
"""Save training configuration to file."""
|
||||
config = {
|
||||
"base_model": self.base_model_var.get(),
|
||||
"output_model": self.output_model_var.get(),
|
||||
"dataset": self.dataset_var.get(),
|
||||
"format": self.format_var.get(),
|
||||
"split": self.split_var.get(),
|
||||
"method": self.method_var.get(),
|
||||
"epochs": self.epochs_var.get(),
|
||||
"batch_size": self.batch_var.get(),
|
||||
"learning_rate": self.lr_var.get(),
|
||||
"lora_rank": self.lora_rank_var.get(),
|
||||
"lora_alpha": self.lora_alpha_var.get(),
|
||||
"max_length": self.max_length_var.get(),
|
||||
"warmup_steps": self.warmup_var.get(),
|
||||
"device": self.device_var.get(),
|
||||
"mixed_precision": self.mixed_precision_var.get(),
|
||||
"gradient_checkpointing": self.grad_checkpoint_var.get()
|
||||
}
|
||||
|
||||
filename = filedialog.asksaveasfilename(
|
||||
title="Save Training Config",
|
||||
defaultextension=".json",
|
||||
filetypes=[("JSON files", "*.json"), ("All files", "*.*")]
|
||||
)
|
||||
|
||||
if filename:
|
||||
with open(filename, 'w') as f:
|
||||
json.dump(config, f, indent=2)
|
||||
self._log(f"Config saved to {filename}")
|
||||
messagebox.showinfo("Saved", f"Configuration saved to {filename}")
|
||||
|
||||
def _load_config(self):
|
||||
"""Load training configuration from file."""
|
||||
filename = filedialog.askopenfilename(
|
||||
title="Load Training Config",
|
||||
filetypes=[("JSON files", "*.json"), ("All files", "*.*")]
|
||||
)
|
||||
|
||||
if filename:
|
||||
try:
|
||||
with open(filename, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Load values from config
|
||||
self.base_model_var.set(config.get("base_model", ""))
|
||||
self.output_model_var.set(config.get("output_model", "my-finetuned-model"))
|
||||
self.dataset_var.set(config.get("dataset", ""))
|
||||
self.format_var.set(config.get("format", "alpaca"))
|
||||
self.split_var.set(config.get("split", "90/10"))
|
||||
self.method_var.set(config.get("method", "LoRA"))
|
||||
self.epochs_var.set(config.get("epochs", "3"))
|
||||
self.batch_var.set(config.get("batch_size", "4"))
|
||||
self.lr_var.set(config.get("learning_rate", "2e-4"))
|
||||
self.lora_rank_var.set(config.get("lora_rank", "8"))
|
||||
self.lora_alpha_var.set(config.get("lora_alpha", "16"))
|
||||
self.max_length_var.set(config.get("max_length", "512"))
|
||||
self.warmup_var.set(config.get("warmup_steps", "100"))
|
||||
self.device_var.set(config.get("device", "cuda"))
|
||||
self.mixed_precision_var.set(config.get("mixed_precision", True))
|
||||
self.grad_checkpoint_var.set(config.get("gradient_checkpointing", True))
|
||||
|
||||
self._log(f"Config loaded from {filename}")
|
||||
messagebox.showinfo("Loaded", f"Configuration loaded from {filename}")
|
||||
|
||||
except Exception as e:
|
||||
messagebox.showerror("Error", f"Failed to load config: {e}")
|
||||
|
||||
def _validate_dataset(self):
|
||||
"""Validate the selected dataset."""
|
||||
if not self.dataset_var.get():
|
||||
messagebox.showerror("Error", "Please select a dataset first")
|
||||
return
|
||||
|
||||
dataset_path = self.dataset_var.get()
|
||||
if not os.path.exists(dataset_path):
|
||||
messagebox.showerror("Error", f"Dataset file not found: {dataset_path}")
|
||||
return
|
||||
|
||||
# TODO: Implement actual dataset validation
|
||||
self._log(f"Validating dataset: {dataset_path}")
|
||||
self._log("Dataset validation (stub) - would check format, size, etc.")
|
||||
messagebox.showinfo("Validation", "Dataset validation complete (stub)")
|
||||
|
||||
def _log(self, message: str):
|
||||
"""Add message to training log."""
|
||||
self.log_text.config(state=tk.NORMAL)
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
self.log_text.insert(tk.END, f"[{timestamp}] {message}\n")
|
||||
self.log_text.see(tk.END)
|
||||
self.log_text.config(state=tk.DISABLED)
|
||||
|
||||
def set_model(self, model_path: Optional[str]):
|
||||
"""Set the current model for fine-tuning."""
|
||||
self.current_model = model_path
|
||||
if model_path:
|
||||
self.base_model_var.set(model_path)
|
||||
self._log(f"Model selected: {Path(model_path).name}")
|
||||
Reference in New Issue
Block a user