cellpose微调模型效果评估工具

为了评估微调后的 cellpose 模型在测试数据集上的效果，我在 Grok 的帮助下编写了一个小工具。

该工具读取 cellpose 保存的 npy 文件，从中随机抽取指定数量的结构（即 mask 中的单个标签），逐个显示给专家判断正确与否。专家完成全部判断后，工具自动计算正确率；每一条评估记录都会即时写入 CSV 文件（labeling_results.csv）。此外，工具还支持断点续评（需在数据目录下启动程序，因为程序从当前工作目录读取上次保存的 config.json）。具体使用效果如下：

![使用效果演示](2190654578.gif)

完整代码如下（依赖 numpy、scipy 和 Pillow，tkinter 为 Python 自带）：

import os
from glob import glob
import numpy as np
import tkinter as tk
from tkinter import messagebox, ttk, filedialog
import random
from PIL import Image, ImageTk
import csv
import json
from scipy.ndimage import center_of_mass
# Default settings; the setup form or a saved config.json can override them.
DEFAULT_NUM_STRUCTURES: int = 200  # how many structures are sampled for checking
DEFAULT_CROP_SIZE: int = 200  # side length of the square crop shown, in pixels
CONFIG_FILE: str = 'config.json'  # file name of the saved session configuration
class LabelingApp:
    """Tkinter GUI in which an expert checks cellpose masks one structure at a time.

    Every non-zero label in the ``masks`` array of each ``*.npy`` file in the chosen
    directory is a "structure". A random subset of the structures that have not been
    labelled yet is shown one crop at a time, with the structure tinted green, and the
    expert answers Correct / Incorrect. Answers are written to ``labeling_results.csv``
    in the data directory after every click, and the sampled subset is saved to
    ``config.json`` in that directory as well.

    Resuming: if ``config.json`` exists in the *current working directory* at start-up,
    its settings are reused and the saved subset is shown again, minus the structures
    already present in ``labeling_results.csv``. Delete ``config.json`` to draw a new
    batch instead.
    """

    RESULTS_FILENAME = 'labeling_results.csv'  # per-directory CSV with the expert's answers

    def __init__(self, root):
        """Initialise state, then resume from ./config.json or show the setup form.

        Args:
            root: the Tk root window the app draws into.
        """
        self.root = root
        self.root.title("Expert Checking")
        self.results = []
        self.num_structures = DEFAULT_NUM_STRUCTURES
        self.crop_size = DEFAULT_CROP_SIZE
        self.directory = None  # no default directory
        # Each entry: (npy_path, label_id, image_path, masks array or None until loaded)
        self.structures = []
        self.current_structure_idx = 0
        self.current_image = None
        self.current_mask = None
        self.current_label = None
        self.current_file = None
        # save_config() writes config.json into the data directory, so a session is only
        # resumed automatically when the script is started from that directory.
        config_path = os.path.join(os.getcwd(), CONFIG_FILE)
        if not os.path.exists(config_path):
            self.show_setup_gui()
            return
        try:
            self.load_config(config_path)
            if not self.directory or not os.path.isdir(self.directory):
                raise ValueError("Invalid directory in config")
            self._validate_settings()
            self.load_existing_results()
        except (ValueError, KeyError, TypeError, IndexError, OSError) as e:
            # json.JSONDecodeError is a ValueError; KeyError/TypeError/IndexError cover
            # missing or malformed entries in config.json or in the results CSV.
            messagebox.showerror("Error", f"Invalid config.json or directory: {str(e)}. Please select a directory.")
            self.show_setup_gui()
            return
        self._build_label_ui()
        if self.structures:
            self._restore_saved_structures()
        else:
            self.load_structures()

    def show_setup_gui(self):
        """Show the setup form (directory, number of structures, crop size), one row each."""
        self.setup_frame = tk.Frame(self.root)
        self.setup_frame.pack(pady=10)
        # Directory entry with a Browse button
        dir_frame = tk.Frame(self.setup_frame)
        dir_frame.pack(fill=tk.X, pady=5)
        tk.Label(dir_frame, text="Directory:").pack(anchor=tk.W)
        dir_inner_frame = tk.Frame(dir_frame)
        dir_inner_frame.pack(fill=tk.X)
        self.dir_entry = tk.Entry(dir_inner_frame, width=50)
        self.dir_entry.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 5))
        tk.Button(dir_inner_frame, text="Browse", command=self.browse_directory).pack(side=tk.LEFT)
        # Number of structures to sample
        num_frame = tk.Frame(self.setup_frame)
        num_frame.pack(fill=tk.X, pady=5)
        tk.Label(num_frame, text="Number of Structures (-1 for all):").pack(anchor=tk.W)
        self.num_entry = tk.Entry(num_frame, width=50)
        self.num_entry.insert(0, str(DEFAULT_NUM_STRUCTURES))
        self.num_entry.pack(fill=tk.X, padx=5)
        # Crop size
        crop_frame = tk.Frame(self.setup_frame)
        crop_frame.pack(fill=tk.X, pady=5)
        tk.Label(crop_frame, text="Crop Size (pixels):").pack(anchor=tk.W)
        self.crop_entry = tk.Entry(crop_frame, width=50)
        self.crop_entry.insert(0, str(DEFAULT_CROP_SIZE))
        self.crop_entry.pack(fill=tk.X, padx=5)
        # Start button
        tk.Button(self.setup_frame, text="Start", command=self.start_labeling).pack(pady=10)

    def browse_directory(self):
        """Let the user pick a directory and put it into the directory entry."""
        directory = filedialog.askdirectory()
        if directory:
            self.dir_entry.delete(0, tk.END)
            self.dir_entry.insert(0, directory)

    def load_existing_results(self):
        """Replace self.results with the rows of the results CSV in self.directory, if any."""
        self.results = []
        if not (self.directory and os.path.exists(self.directory)):
            return
        csv_path = os.path.join(self.directory, self.RESULTS_FILENAME)
        if not os.path.exists(csv_path):
            return
        with open(csv_path, 'r', newline='') as csvfile:
            for row in csv.DictReader(csvfile):
                self.results.append({
                    'file': row['file'],
                    'structure_id': int(row['structure_id']),
                    'expert_label': row['expert_label'] == 'True',
                })

    def load_config(self, config_path):
        """Load directory, sample size, crop size and the saved batch from a JSON config.

        Masks are not stored in the config; the fourth tuple element is None until
        _restore_saved_structures() re-reads them from the npy files.
        """
        with open(config_path, 'r') as f:
            config = json.load(f)
        self.directory = config['directory']
        self.num_structures = int(config['num_structures'])
        self.crop_size = int(config['crop_size'])
        self.structures = [(s[0], int(s[1]), s[2], None) for s in config.get('structures', [])]

    def save_config(self):
        """Write the current settings and sampled batch to config.json in self.directory."""
        config = {
            'directory': str(self.directory),
            'num_structures': int(self.num_structures),
            'crop_size': int(self.crop_size),
            # numpy scalars are not JSON-serialisable, hence the explicit conversions
            'structures': [(str(s[0]), int(s[1]), str(s[2])) for s in self.structures],
        }
        config_path = os.path.join(self.directory, CONFIG_FILE)
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)

    def _validate_settings(self):
        """Raise ValueError if the directory, structure count or crop size is unusable."""
        if not self.directory:
            raise ValueError("Please select a directory")
        if self.num_structures != -1 and self.num_structures <= 0:
            raise ValueError("Number of structures must be -1 or a positive integer")
        if self.crop_size <= 0:
            raise ValueError("Crop size must be positive")
        if not os.path.isdir(self.directory):
            raise ValueError("Invalid directory")

    def _build_label_ui(self):
        """Create the canvas, the progress label and the Correct/Incorrect buttons."""
        self.label_frame = tk.Frame(self.root)
        self.canvas = tk.Canvas(self.label_frame, width=self.crop_size, height=self.crop_size)
        self.canvas.pack()
        self.label_var = tk.StringVar()
        self.label_var.set("Structure 0/0")
        tk.Label(self.label_frame, textvariable=self.label_var).pack()
        # Separate frame so the two buttons sit centred side by side
        button_frame = tk.Frame(self.label_frame)
        button_frame.pack(pady=10)
        tk.Button(button_frame, text="Correct", command=self.label_correct, width=10).pack(side=tk.LEFT, padx=10, fill=tk.X, expand=True)
        tk.Button(button_frame, text="Incorrect", command=self.label_incorrect, width=10).pack(side=tk.LEFT, padx=10, fill=tk.X, expand=True)
        self.label_frame.pack()

    def start_labeling(self):
        """Validate the setup form, then switch to the labelling view and sample a batch."""
        try:
            self.directory = self.dir_entry.get().strip()
            self.num_structures = int(self.num_entry.get())
            self.crop_size = int(self.crop_entry.get())
            self._validate_settings()
        except ValueError as e:
            messagebox.showerror("Error", str(e))
            return
        # Reload results so the newly chosen directory's CSV is used
        self.load_existing_results()
        self.setup_frame.pack_forget()
        self._build_label_ui()
        self.load_structures()

    def _read_seg_file(self, fp):
        """Return the dict stored in a cellpose ``*.npy`` file, or None (with a warning).

        NOTE: allow_pickle=True unpickles the file contents; only open trusted files.
        """
        try:
            rec = np.load(fp, allow_pickle=True).item()
            if not isinstance(rec, dict):
                raise ValueError("file does not contain a dict")
        except Exception as e:  # any unreadable / unexpected file is reported and skipped
            messagebox.showwarning("Warning", f"Error loading {fp}: {str(e)}")
            return None
        return rec

    def _resolve_image_path(self, img_path):
        """Return img_path, or the same file name inside self.directory if img_path is gone.

        The npy 'filename' entry is presumably the image path at segmentation time; this
        fallback keeps the tool usable after the dataset folder has been moved.
        """
        if not img_path or os.path.exists(img_path):
            return img_path
        candidate = os.path.join(self.directory, os.path.basename(img_path))
        return candidate if os.path.exists(candidate) else img_path

    def _labeled_keys(self):
        """Return the set of (npy_path, structure_id) pairs already present in self.results."""
        return {(r['file'], r['structure_id']) for r in self.results}

    def load_structures(self):
        """Sample a fresh batch of unlabelled structures from self.directory and show it.

        Collects every non-zero label of every ``*.npy`` file, drops those already in
        self.results, samples num_structures of them (-1, or a value at least as large
        as what is left, takes all), shuffles, saves the batch to config.json and starts
        displaying.
        """
        fps = glob(os.path.join(self.directory, '*.npy'))
        if not fps:
            messagebox.showerror("Error", "No .npy files found in the specified directory")
            self._quit()
            return
        self.structures = []
        all_structures = []
        for fp in fps:
            rec = self._read_seg_file(fp)
            if rec is None or rec.get('masks') is None:
                continue  # unreadable file, or a file without masks
            masks = rec['masks']
            img_path = self._resolve_image_path(rec.get('filename') or '')
            all_structures.extend(
                (fp, label, img_path, masks) for label in np.unique(masks) if label != 0
            )
        if not all_structures:
            messagebox.showerror("Error", "No valid structures found in .npy files")
            self._quit()
            return
        labeled_set = self._labeled_keys()
        remaining_structures = [s for s in all_structures if (s[0], s[1]) not in labeled_set]
        if self.num_structures == -1 or self.num_structures >= len(remaining_structures):
            self.structures = remaining_structures
        else:
            self.structures = random.sample(remaining_structures, self.num_structures)
        random.shuffle(self.structures)
        self.save_config()
        if not self.structures:
            messagebox.showerror("Error", "No structures available to label after filtering.")
            self._quit()
            return
        self.show_structure()

    def _restore_saved_structures(self):
        """Continue the batch loaded from config.json instead of sampling a new one.

        Structures already in self.results are dropped; masks are re-read once per npy
        file. config.json is left untouched so a later restart resumes the same batch.
        """
        labeled = self._labeled_keys()
        masks_by_file = {}
        restored = []
        for fp, label, img_path, _ in self.structures:
            if (fp, label) in labeled:
                continue
            if fp not in masks_by_file:
                rec = self._read_seg_file(fp)
                masks_by_file[fp] = None if rec is None else rec.get('masks')
            masks = masks_by_file[fp]
            if masks is None:
                continue
            restored.append((fp, label, self._resolve_image_path(img_path), masks))
        self.structures = restored
        self.show_structure()

    def _load_image(self, img_path):
        """Return the image at img_path as an RGB uint8 array, or None (with a warning)."""
        try:
            with Image.open(img_path) as im:
                return np.array(im.convert('RGB'))
        except FileNotFoundError:
            messagebox.showwarning("Warning", f"Image file {img_path} not found, skipping.")
        except OSError as e:
            # Pillow signals unreadable/unknown formats with OSError subclasses
            messagebox.showwarning("Warning", f"Cannot read image file {img_path} ({e}), skipping.")
        return None

    def _next_displayable(self):
        """Advance current_structure_idx to the next structure that can be drawn.

        Skips (with a warning, without recording a result) structures whose image cannot
        be read, whose mask does not match the image size, or whose label is absent.
        Returns (npy_path, label, masks, image, binary_mask), or None when none remain.
        """
        while self.current_structure_idx < len(self.structures):
            npy_path, label, img_path, masks = self.structures[self.current_structure_idx]
            image = self._load_image(img_path)
            if image is not None:
                mask = (masks == label).astype(np.uint8)
                if mask.shape == image.shape[:2] and mask.any():
                    return npy_path, label, masks, image, mask
                messagebox.showwarning("Warning", f"Structure {label} in {npy_path} does not match image {img_path}, skipping.")
            self.current_structure_idx += 1
        return None

    @staticmethod
    def _crop_around(array, center_y, center_x, size):
        """Return a size x size window of array centred on (center_y, center_x).

        The window starts at ``center - size // 2`` and always spans exactly ``size``
        pixels (odd sizes included); parts outside the array are zero-filled. Extra
        trailing axes (e.g. RGB channels) are kept.
        """
        top = center_y - size // 2
        left = center_x - size // 2
        out = np.zeros((size, size) + array.shape[2:], dtype=array.dtype)
        h, w = array.shape[:2]
        src_y0, src_y1 = max(0, top), min(h, top + size)
        src_x0, src_x1 = max(0, left), min(w, left + size)
        if src_y1 > src_y0 and src_x1 > src_x0:
            out[src_y0 - top:src_y1 - top, src_x0 - left:src_x1 - left] = \
                array[src_y0:src_y1, src_x0:src_x1]
        return out

    @staticmethod
    def _make_overlay(image, mask):
        """Blend a green mask into image as 0.9 * image + 0.1 * green(mask), in uint8."""
        mask_rgb = np.stack([mask * 0, mask * 255, mask * 0], axis=-1)
        return np.clip(image * 0.9 + mask_rgb * 0.1, 0, 255).astype(np.uint8)

    def show_structure(self):
        """Draw the current structure, or save, report accuracy and close when done."""
        found = self._next_displayable()
        if found is None:
            self.save_results()
            messagebox.showinfo("Done", f"All structures processed. Results saved.\nAccuracy: {self.calculate_accuracy():.2%}")
            self._quit()
            return
        self.current_file, self.current_label, self.current_mask, self.current_image, mask = found
        self.label_var.set(f"Structure {self.current_structure_idx + 1}/{len(self.structures)}")
        # Centre the crop on the structure's centroid
        centroid_y, centroid_x = center_of_mass(mask)
        centroid_y, centroid_x = int(centroid_y), int(centroid_x)
        cropped_image = self._crop_around(self.current_image, centroid_y, centroid_x, self.crop_size)
        cropped_mask = self._crop_around(mask, centroid_y, centroid_x, self.crop_size)
        overlay = self._make_overlay(cropped_image, cropped_mask)
        # Stored on self so the PhotoImage is not garbage-collected while displayed
        self.photo = ImageTk.PhotoImage(Image.fromarray(overlay))
        self.canvas.delete("all")  # drop the previous image item instead of stacking them
        self.canvas.create_image(0, 0, anchor=tk.NW, image=self.photo)

    def _quit(self):
        """Close the app; destroy() also works before mainloop() starts, quit() does not."""
        self.root.destroy()

    def label_correct(self):
        """Record the current structure as correct and move on."""
        self.save_label(True)
        self.current_structure_idx += 1
        self.show_structure()

    def label_incorrect(self):
        """Record the current structure as incorrect and move on."""
        self.save_label(False)
        self.current_structure_idx += 1
        self.show_structure()

    def save_label(self, expert_label):
        """Append the expert's verdict for the current structure and persist all results."""
        self.results.append({
            'file': self.current_file,
            'structure_id': int(self.current_label),  # plain int, like rows read from the CSV
            'expert_label': expert_label,
        })
        self.save_results()

    def calculate_accuracy(self):
        """Return the fraction of recorded results labelled correct (0.0 if none)."""
        if not self.results:
            return 0.0
        correct = sum(1 for r in self.results if r['expert_label'])
        return correct / len(self.results)

    def save_results(self):
        """Rewrite the results CSV in self.directory with all of self.results.

        Rows go to a temporary file that then replaces the CSV, so an interruption
        mid-write leaves the previous results file intact.
        """
        csv_path = os.path.join(self.directory, self.RESULTS_FILENAME)
        tmp_path = csv_path + '.tmp'
        with open(tmp_path, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=['file', 'structure_id', 'expert_label'])
            writer.writeheader()
            writer.writerows(self.results)
        os.replace(tmp_path, csv_path)
# Script entry point: create the Tk root window, attach the labelling app to it
# and hand control to the Tk event loop.
if __name__ == "__main__":
    root = tk.Tk()
    app = LabelingApp(root)
    app.root.mainloop()

使用时需要注意 Python 运行环境：最好在与 cellpose（目前支持的版本为 4.0.1）相同的 Python 环境中运行该程序，以免读取 cellpose 保存的 npy 文件时出错。