基于深度学习的工业质检全流程开发指南
一、工业缺陷检测技术选型
主流工业视觉检测方案对比:
技术方案 | 准确率 | 处理速度 | 硬件成本 |
---|---|---|---|
传统图像处理 | 70-85% | 快(50-100ms) | 低 |
机器学习 | 85-92% | 中(100-300ms) | 中 |
深度学习 | 93-99% | 慢(300-1000ms) | 高 |
混合方案 | 90-97% | 中快(150-500ms) | 中高 |
二、系统架构设计
1. 智能检测系统架构
图像采集 → 预处理 → 缺陷检测 → 分类定位 → 结果输出
↑ ↑ ↑ ↑ ↑
工业相机 OpenCV处理 PyTorch模型 后处理分析 MES系统对接
2. 模型训练流程
数据收集 → 标注清洗 → 增强扩充 → 模型训练 → 量化压缩 → 部署上线
↑ ↑ ↑ ↑ ↑ ↑
产线采集 Labelme工具 Albumentations 迁移学习 TensorRT优化 FastAPI服务
三、核心模块实现
1. 工业图像预处理
import cv2
import numpy as np
from albumentations import (
Compose, Rotate, RandomBrightnessContrast,
GaussNoise, OpticalDistortion
)
class ImagePreprocessor:
    """Convert an industrial camera image into a normalized CHW float32 tensor.

    Pipeline: load (or accept) a BGR image -> RGB -> resize -> optional
    augmentation -> scale to [0, 1] -> ImageNet mean/std standardization ->
    ``torch`` tensor of shape ``(3, H, W)``.
    """

    # ImageNet channel statistics. They are stored as float32 on purpose:
    # float64 constants would silently promote the image to float64 and the
    # resulting double tensor is rejected by float32 model weights.
    _MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    _STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, target_size=(512, 512)):
        """Create the preprocessor.

        Args:
            target_size: ``(width, height)`` passed to ``cv2.resize``.
        """
        self.target_size = target_size
        self.augmentations = Compose([
            Rotate(limit=15, p=0.5),
            RandomBrightnessContrast(p=0.3),
            GaussNoise(var_limit=(10, 50), p=0.2),
            OpticalDistortion(p=0.1)
        ])

    def process(self, image_path, is_training=False):
        """Load an image and return it as a normalized tensor.

        Args:
            image_path: Path of the image file, or an already decoded BGR
                ``numpy`` array (as returned by OpenCV). Other callers in
                this file pass decoded frames, so both forms are accepted.
            is_training: When true, apply the random augmentations.

        Returns:
            ``torch.FloatTensor`` of shape ``(3, H, W)``.

        Raises:
            ValueError: If the file cannot be read or decoded.
        """
        if isinstance(image_path, np.ndarray):
            image = image_path
        else:
            image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            # cv2.imread signals failure by returning None instead of raising.
            if image is None:
                raise ValueError(f"无法读取图像: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, self.target_size)

        # Augment while the image is still uint8: var_limit=(10, 50) is
        # presumably expressed on the 0-255 intensity scale (albumentations
        # 1.x behaviour - confirm for the installed version); applied to a
        # [0, 1] float image it would drown the content in noise.
        if is_training:
            image = self.augmentations(image=image)['image']

        return self._to_tensor(image)

    @staticmethod
    def _to_tensor(image):
        """Scale an RGB HWC uint8 image to [0, 1], standardize, return CHW tensor."""
        # Local import: the import list next to this class does not include torch.
        import torch

        image = image.astype(np.float32) / 255.0
        image = (image - ImagePreprocessor._MEAN) / ImagePreprocessor._STD
        # HWC -> CHW; make the memory contiguous before handing it to torch.
        image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))
        return torch.from_numpy(image)

    def adaptive_threshold(self, image):
        """Return an inverted binary mask from Gaussian adaptive thresholding.

        Args:
            image: RGB image; it is converted to grayscale first.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        return cv2.adaptiveThreshold(
            gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV, 11, 2
        )

    def remove_background(self, image):
        """Suppress background with a morphological opening followed by dilation."""
        kernel = np.ones((5, 5), np.uint8)
        opening = cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel, iterations=2)
        return cv2.dilate(opening, kernel, iterations=3)
2. 缺陷检测模型
import torch
import torch.nn as nn
from torchvision.models import resnet34
class DefectDetector(nn.Module):
    """ResNet34-based network with a classification head and a localization head.

    ``forward`` returns ``(cls_output, loc_output)``:

    * ``cls_output`` - raw class logits of shape ``(N, num_classes)``.
    * ``loc_output`` - sigmoid defect map of shape ``(N, 1, h, w)`` where
      ``h, w`` are 4x the spatial size of the backbone's last feature map
      (two 2x upsampling stages).
    """

    def __init__(self, num_classes=5, pretrained=True):
        """Build the network.

        Args:
            num_classes: Number of defect classes predicted by the classifier.
            pretrained: Forwarded to ``resnet34``; pass ``False`` when the
                weights are about to be overwritten by a checkpoint.
        """
        super().__init__()
        self.backbone = resnet34(pretrained=pretrained)

        # The original fully connected layer is replaced by an identity so
        # the backbone exposes pooled features; the custom head follows.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        self.classifier = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

        # Localization branch operating on the 512-channel layer4 feature map.
        self.locator = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def _feature_map(self, x):
        """Run the convolutional part of the backbone and return the layer4 map."""
        backbone = self.backbone
        x = backbone.conv1(x)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)
        x = backbone.layer1(x)
        x = backbone.layer2(x)
        x = backbone.layer3(x)
        return backbone.layer4(x)

    def forward(self, x):
        """Return ``(class logits, localization map)`` for an image batch ``x``."""
        # layer4 cannot be applied to the raw image (it expects the output of
        # layer3), so the backbone is unrolled once and the resulting feature
        # map is shared by both heads.
        feature_map = self._feature_map(x)

        pooled = torch.flatten(self.backbone.avgpool(feature_map), 1)
        cls_output = self.classifier(pooled)

        loc_output = self.locator(feature_map)
        return cls_output, loc_output
class DefectLoss(nn.Module):
    """Weighted combination of a classification and a localization loss.

    ``total = alpha * CrossEntropy(cls) + (1 - alpha) * BCE(loc)``
    """

    def __init__(self, alpha=0.7):
        """Store the mixing weight ``alpha`` and create both criteria."""
        super().__init__()
        self.alpha = alpha
        self.cls_loss = nn.CrossEntropyLoss()
        self.loc_loss = nn.BCELoss()

    def forward(self, outputs, targets):
        """Compute the combined loss.

        Args:
            outputs: ``(cls_pred, loc_pred)`` pair produced by the model.
            targets: ``(cls_target, loc_target)`` pair of ground truths.

        Returns:
            Scalar tensor with the weighted sum of both losses.
        """
        (cls_pred, loc_pred), (cls_target, loc_target) = outputs, targets

        classification_term = self.cls_loss(cls_pred, cls_target)
        localization_term = self.loc_loss(loc_pred, loc_target)

        return (self.alpha * classification_term
                + (1 - self.alpha) * localization_term)
四、高级功能实现
1. 小样本学习增强
import torch
from torch.utils.data import Dataset
from sklearn.utils.class_weight import compute_class_weight
class FewShotDataset(Dataset):
    """Image dataset that also exposes balanced per-class weights.

    Items are ``(image, label)`` pairs; the image is opened from disk as RGB
    and passed through ``transform`` when one is given.
    """

    def __init__(self, image_paths, labels, transform=None):
        """Keep the paths/labels and precompute ``class_weights``."""
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

        # 'balanced' weights, one entry per distinct label in np.unique order.
        distinct_labels = np.unique(labels)
        self.class_weights = compute_class_weight(
            'balanced', classes=distinct_labels, y=labels)

    def __len__(self):
        """Number of samples, i.e. the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        sample = Image.open(self.image_paths[idx]).convert('RGB')
        target = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, target
class FewShotSampler:
    """Endless N-way K-shot episode sampler.

    Each iteration step yields one flat list of dataset indices: the first
    ``n_way * k_shot`` entries form the support set, the remaining
    ``n_way * query_num`` entries form the query set.
    """

    def __init__(self, dataset, n_way=5, k_shot=3, query_num=5):
        """Index the dataset by class and validate the episode configuration.

        Args:
            dataset: Dataset exposing a ``labels`` sequence; as a fallback it
                is iterated and must then yield ``(sample, label)`` pairs.
            n_way: Number of classes per episode.
            k_shot: Support samples per class.
            query_num: Query samples per class.

        Raises:
            ValueError: If there are fewer than ``n_way`` classes, or a class
                has fewer than ``k_shot + query_num`` samples. Failing here
                replaces a sporadic error in the middle of iteration.
        """
        self.dataset = dataset
        self.n_way = n_way
        self.k_shot = k_shot
        self.query_num = query_num

        # Group sample indices by class label.
        self.class_indices = {}
        for idx, label in enumerate(self._collect_labels(dataset)):
            self.class_indices.setdefault(label, []).append(idx)
        self.classes = list(self.class_indices.keys())

        if n_way > len(self.classes):
            raise ValueError(
                f"n_way={n_way} exceeds the number of classes "
                f"({len(self.classes)})")
        needed = k_shot + query_num
        too_small = [cls for cls, members in self.class_indices.items()
                     if len(members) < needed]
        if too_small:
            raise ValueError(
                f"classes {too_small} have fewer than "
                f"k_shot + query_num = {needed} samples")

    @staticmethod
    def _collect_labels(dataset):
        """Return hashable labels without loading samples when possible."""
        labels = getattr(dataset, 'labels', None)
        if labels is None:
            # Fallback: walking the dataset loads every sample just to read
            # its label, which is slow for image datasets.
            labels = [label for _, label in dataset]
        # Unwrap numpy / tensor scalars so equal labels share one dict key.
        return [label.item() if hasattr(label, 'item') else label
                for label in labels]

    def __iter__(self):
        """Yield episodes forever; the consumer decides when to stop."""
        per_class = self.k_shot + self.query_num
        while True:
            # Draw class positions rather than the labels themselves so that
            # numpy never coerces the label values.
            chosen = np.random.choice(
                len(self.classes), self.n_way, replace=False)
            support_set = []
            query_set = []
            for class_pos in chosen:
                members = self.class_indices[self.classes[class_pos]]
                # tolist() converts numpy integers to plain Python ints.
                samples = np.random.choice(
                    members, per_class, replace=False).tolist()
                support_set.extend(samples[:self.k_shot])
                query_set.extend(samples[self.k_shot:])
            yield support_set + query_set
2. 模型解释性可视化
import matplotlib.pyplot as plt
from captum.attr import IntegratedGradients
from captum.attr import visualization as viz
class ModelExplainer:
    """Attribution and saliency visualizations for a classification model.

    Models that return several outputs (for example a ``(class scores,
    localization map)`` tuple) are supported: only the first output is
    explained.
    """

    def __init__(self, model):
        """Wrap ``model`` and prepare Integrated Gradients on its class scores."""
        self.model = model
        # Attribution needs a single tensor output, so it is bound to the
        # class-score wrapper rather than to the raw model.
        self.ig = IntegratedGradients(self._class_scores)

    def _class_scores(self, inputs):
        """Run the model and return its class-score tensor (first output)."""
        output = self.model(inputs)
        if isinstance(output, (tuple, list)):
            output = output[0]
        return output

    def explain(self, input_tensor, target_class, n_steps=50):
        """Plot the input next to its Integrated Gradients attribution.

        Args:
            input_tensor: Image batch; assumed to hold exactly one image,
                i.e. shape ``(1, C, H, W)`` - it is squeezed before plotting.
            target_class: Index of the class whose score is attributed.
            n_steps: Number of integration steps.

        Returns:
            The matplotlib figure containing both panels.
        """
        attributions = self.ig.attribute(
            input_tensor,
            target=target_class,
            n_steps=n_steps
        )

        # CHW -> HWC, the layout expected by matplotlib and by captum's viz.
        img = input_tensor.squeeze().permute(1, 2, 0).cpu().detach().numpy()
        attr = attributions.squeeze().permute(1, 2, 0).cpu().detach().numpy()

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))

        ax[0].imshow(img)
        ax[0].set_title('Original Image')
        ax[0].axis('off')

        # attr is already HWC; transposing it back to CHW (as the earlier
        # version did) hands captum a wrongly shaped array.
        viz.visualize_image_attr(
            attr,
            original_image=img,
            method='blended_heat_map',
            sign='absolute_value',
            show_colorbar=True,
            title='Feature Importance',
            plt_fig_axis=(fig, ax[1])
        )
        return fig

    def generate_saliency_map(self, input_tensor, target_class):
        """Plot and return a gradient saliency map.

        Args:
            input_tensor: Single image of shape ``(C, H, W)``; a batch
                dimension is added internally.
            target_class: Index of the class score to differentiate.

        Returns:
            ``(H, W)`` tensor: maximum absolute input gradient over channels.
        """
        # Work on a private copy so the caller's tensor is not switched to
        # requires_grad and no gradient accumulates across calls.
        inputs = input_tensor.detach().clone().unsqueeze(0).requires_grad_(True)

        scores = self._class_scores(inputs)
        # autograd.grad returns the input gradient directly and leaves the
        # model parameters' .grad buffers untouched.
        (grad,) = torch.autograd.grad(scores[0, target_class], inputs)

        saliency = grad[0].abs().max(dim=0)[0]

        plt.figure(figsize=(6, 6))
        plt.imshow(saliency.cpu(), cmap='hot')
        plt.colorbar()
        plt.title('Saliency Map')
        plt.axis('off')
        return saliency
五、生产部署优化
1. 模型量化与加速
import torch
from torch.quantization import quantize_dynamic
from torch.utils.mobile_optimizer import optimize_for_mobile
class ModelOptimizer:
    """Quantization and export helpers for deploying a trained model."""

    def __init__(self, model):
        """Keep ``model`` switched to evaluation mode."""
        self.model = model.eval()

    def dynamic_quantization(self):
        """Return a dynamically quantized (int8) copy of the model.

        Only ``nn.Linear`` layers are requested: dynamic quantization has no
        kernel for ``nn.Conv2d``, so listing it would have no effect.
        """
        quantized_model = quantize_dynamic(
            self.model,
            {torch.nn.Linear},
            dtype=torch.qint8
        )
        return quantized_model

    def static_quantization(self, calibration_data):
        """Return a statically quantized copy calibrated on sample inputs.

        Uses the eager-mode post-training workflow (qconfig -> prepare ->
        calibrate -> convert). ``quantize_qat`` is a training-time API with a
        different signature and cannot be used here.

        NOTE(review): eager-mode static quantization requires the model to
        wrap its quantized region in ``QuantStub`` / ``DeQuantStub``; models
        without them convert but fail at inference - confirm for the target
        model.

        Args:
            calibration_data: Iterable of representative model inputs.
        """
        import copy

        # Work on a copy so the wrapped float model stays untouched.
        model = copy.deepcopy(self.model)
        model.qconfig = torch.quantization.get_default_qconfig(
            torch.backends.quantized.engine)
        prepared = torch.quantization.prepare(model, inplace=True)

        # Calibration: observers record activation ranges.
        with torch.no_grad():
            for data in calibration_data:
                prepared(data)

        return torch.quantization.convert(prepared, inplace=True)

    def optimize_for_mobile(self, output_path):
        """Script the model, optimize it for mobile and save it for the lite interpreter."""
        scripted_model = torch.jit.script(self.model)
        # Resolves to the module-level function imported from
        # torch.utils.mobile_optimizer, not to this method.
        optimized_model = optimize_for_mobile(scripted_model)
        optimized_model._save_for_lite_interpreter(output_path)
        return optimized_model

    def convert_to_onnx(self, output_path, input_shape=(1, 3, 512, 512)):
        """Export the model to ONNX with a dynamic batch dimension.

        Args:
            output_path: Destination ``.onnx`` file.
            input_shape: Shape of the dummy input used for tracing.
        """
        # The dummy input must live on the same device as the weights.
        try:
            device = next(self.model.parameters()).device
        except StopIteration:
            device = torch.device('cpu')
        dummy_input = torch.randn(*input_shape, device=device)

        torch.onnx.export(
            self.model,
            dummy_input,
            output_path,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'}
            }
        )

    def tensorrt_optimization(self, onnx_path, output_path,
                              batch_range=(1, 1, 8)):
        """Build a serialized TensorRT engine from an ONNX file.

        Args:
            onnx_path: ONNX model to parse.
            output_path: Destination of the serialized engine.
            batch_range: ``(min, opt, max)`` sizes substituted for dynamic
                input dimensions in the optimization profile.

        Raises:
            ValueError: If the ONNX file cannot be parsed.
            RuntimeError: If TensorRT fails to build the engine.
        """
        import tensorrt as trt

        logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(logger)
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        parser = trt.OnnxParser(network, logger)

        with open(onnx_path, 'rb') as f:
            if not parser.parse(f.read()):
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                raise ValueError("ONNX解析失败")

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB

        # Inputs exported with a dynamic axis need an optimization profile,
        # otherwise the build fails. Every dynamic dimension is treated as
        # the batch dimension - matches convert_to_onnx, which only marks
        # axis 0 as dynamic.
        profile = builder.create_optimization_profile()
        has_dynamic_input = False
        for i in range(network.num_inputs):
            tensor = network.get_input(i)
            shape = tuple(tensor.shape)
            if -1 not in shape:
                continue
            has_dynamic_input = True
            min_shape, opt_shape, max_shape = (
                tuple(size if dim == -1 else dim for dim in shape)
                for size in batch_range
            )
            profile.set_shape(tensor.name, min_shape, opt_shape, max_shape)
        if has_dynamic_input:
            config.add_optimization_profile(profile)

        serialized_engine = builder.build_serialized_network(network, config)
        # A failed build returns None; check before touching the output file.
        if serialized_engine is None:
            raise RuntimeError("TensorRT引擎构建失败")
        with open(output_path, 'wb') as f:
            f.write(serialized_engine)
2. 高性能推理服务
from fastapi import FastAPI, File, UploadFile
import uvicorn
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
import io
app = FastAPI()
model = None  # populated by load_model() when the application starts
executor = ThreadPoolExecutor(max_workers=4)


def load_model():
    """Load the trained weights from ``best_model.pth`` into the global ``model``.

    The checkpoint is mapped onto the available device so that weights saved
    on a GPU machine also load on a CPU-only host.
    """
    global model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    state_dict = torch.load('best_model.pth', map_location=device)

    # Size the classification head from the checkpoint itself: the last
    # layer of DefectDetector.classifier is stored as 'classifier.3.weight',
    # and a hard-coded class count may not match the training run.
    head_weight = state_dict.get('classifier.3.weight')
    if head_weight is not None:
        detector = DefectDetector(num_classes=head_weight.shape[0])
    else:
        detector = DefectDetector()

    detector.load_state_dict(state_dict)
    detector.eval()
    # Publish the model only once it is fully initialized.
    model = detector.to(device)


@app.on_event("startup")
async def startup_event():
    """Load the model once when the FastAPI application starts."""
    load_model()
async def process_image(image_bytes):
    """Preprocess an uploaded image off the event loop and run inference.

    Args:
        image_bytes: Raw bytes of the uploaded image file.

    Returns:
        Dict with the predicted ``class`` index, its softmax ``confidence``
        and the localization ``heatmap`` as nested lists.
    """
    # Local import: asyncio is not imported at module level in this file.
    import asyncio

    loop = asyncio.get_running_loop()

    # CPU-bound decoding/resizing runs in the thread pool.
    image = await loop.run_in_executor(
        executor,
        preprocess_image,
        image_bytes
    )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Inference itself runs on the event loop thread.
    with torch.no_grad():
        # Cast to float32 so the input always matches the model weights.
        image = image.to(device=device, dtype=torch.float32)
        cls_output, loc_output = model(image.unsqueeze(0))

    probabilities = torch.softmax(cls_output, dim=1)
    return {
        'class': torch.argmax(cls_output).item(),
        'confidence': torch.max(probabilities).item(),
        'heatmap': loc_output.squeeze().cpu().numpy().tolist()
    }
def preprocess_image(image_bytes):
    """Decode uploaded image bytes and return the normalized input tensor.

    ``ImagePreprocessor.process`` reads its input with ``cv2.imread``, i.e.
    it takes a file path, so the decoded image is written to a temporary
    lossless PNG file and that path is handed over.

    Args:
        image_bytes: Raw bytes of an image in any format Pillow can open.

    Returns:
        float32 tensor produced by ``ImagePreprocessor.process``.
    """
    import os
    import tempfile

    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

    fd, tmp_path = tempfile.mkstemp(suffix='.png')
    os.close(fd)
    try:
        image.save(tmp_path, format='PNG')
        preprocessor = ImagePreprocessor()
        tensor = preprocessor.process(tmp_path)
    finally:
        os.remove(tmp_path)

    # Guarantee float32 regardless of how the preprocessor normalized.
    return tensor.float()
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Prediction endpoint: classify and localize defects in an uploaded image.

    Returns ``{"status": "success", "result": ...}`` on success. On failure
    the same error body as before is returned, but with HTTP status 500 so
    that clients checking the status code do not mistake it for a result.
    """
    from fastapi.responses import JSONResponse

    try:
        image_bytes = await file.read()
        result = await process_image(image_bytes)
    except Exception as e:
        # Top-level API boundary: report the failure instead of crashing.
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": str(e)
            }
        )
    return {
        "status": "success",
        "result": result
    }


if __name__ == "__main__":
    # Listens on all network interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
六、实战案例:PCB板缺陷检测
1. 数据集构建与训练
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from sklearn.model_selection import train_test_split
import glob

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. Collect the dataset (sorted so the file order is deterministic).
image_paths = sorted(glob.glob('dataset/pcb_images/*.jpg'))
# get_label_from_path is a user-supplied function mapping a path to a label.
labels = [get_label_from_path(p) for p in image_paths]

# Stratified train / validation split.
train_paths, val_paths, train_labels, val_labels = train_test_split(
    image_paths, labels, test_size=0.2, stratify=labels
)

# 2. Data loaders.
# NOTE(review): DefectDataset is not defined in this file. The loops below
# unpack each batch as (images, (cls_targets, loc_targets)), so it must
# yield a class index plus a localization mask per image - confirm.
train_dataset = DefectDataset(train_paths, train_labels, transform=ToTensor())
val_dataset = DefectDataset(val_paths, val_labels, transform=ToTensor())
train_loader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, num_workers=4
)
val_loader = DataLoader(
    val_dataset, batch_size=16, shuffle=False, num_workers=2
)

# 3. Model, optimizer and loss.
model = DefectDetector(num_classes=6).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = DefectLoss(alpha=0.7)

# 4. Training loop.
best_val_acc = float('-inf')
for epoch in range(50):
    model.train()
    train_loss = 0.0
    for images, (cls_targets, loc_targets) in train_loader:
        images = images.to(device)
        cls_targets = cls_targets.to(device)
        loc_targets = loc_targets.to(device)

        optimizer.zero_grad()
        cls_output, loc_output = model(images)
        loss = criterion((cls_output, loc_output), (cls_targets, loc_targets))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation pass.
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, (cls_targets, loc_targets) in val_loader:
            images = images.to(device)
            cls_targets = cls_targets.to(device)
            loc_targets = loc_targets.to(device)

            cls_output, loc_output = model(images)
            loss = criterion((cls_output, loc_output), (cls_targets, loc_targets))
            val_loss += loss.item()

            predicted = cls_output.argmax(dim=1)
            total += cls_targets.size(0)
            correct += (predicted == cls_targets).sum().item()

    # max(..., 1) keeps the report from dividing by zero on empty loaders.
    val_acc = 100 * correct / max(total, 1)
    print(f'Epoch {epoch+1}: '
          f'Train Loss: {train_loss/max(len(train_loader), 1):.4f}, '
          f'Val Loss: {val_loss/max(len(val_loader), 1):.4f}, '
          f'Val Acc: {val_acc:.2f}%')

    # Keep the best checkpoint; the inference service loads 'best_model.pth'.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')
2. 产线集成方案
import cv2
import requests
from datetime import datetime
class ProductionLineIntegration:
    """Grab frames from a camera, send them to the prediction API and log defects."""

    def __init__(self, camera_index=0,
                 api_endpoint="http://localhost:8000/predict",
                 request_timeout=10.0):
        """Open the camera and reset the counters.

        Args:
            camera_index: Index passed to ``cv2.VideoCapture``.
            api_endpoint: URL of the prediction endpoint.
            request_timeout: Seconds to wait for the API before giving up,
                so a stalled service cannot block the line indefinitely.
        """
        self.camera = cv2.VideoCapture(camera_index)
        self.api_endpoint = api_endpoint
        self.request_timeout = request_timeout
        self.defect_count = 0
        self.total_count = 0

    def process_frame(self, frame):
        """Send one frame to the API and return its ``result`` payload.

        The frame is uploaded as JPEG; all preprocessing happens on the
        server, so none is done here.

        Raises:
            Exception: If encoding fails, the API answers with a non-200
                status, or the response carries no ``result``.
        """
        encoded_ok, img_encoded = cv2.imencode('.jpg', frame)
        if not encoded_ok:
            raise RuntimeError("图像编码失败")
        img_bytes = img_encoded.tobytes()

        files = {'file': ('frame.jpg', img_bytes, 'image/jpeg')}
        response = requests.post(
            self.api_endpoint, files=files, timeout=self.request_timeout)

        if response.status_code != 200:
            raise Exception(f"API调用失败: {response.text}")
        payload = response.json()
        # The service may report an error in the body of a 200 response.
        if 'result' not in payload:
            raise Exception(f"API调用失败: {response.text}")
        return payload['result']

    def run(self):
        """Main loop: read frames until the stream ends or 'q' is pressed."""
        try:
            while True:
                ret, frame = self.camera.read()
                if not ret:
                    break
                try:
                    result = self.process_frame(frame)
                    self.total_count += 1
                    if result['class'] != 0:  # class 0 means "no defect"
                        self.defect_count += 1
                        self.log_defect(frame, result)
                    self.display_result(frame, result)
                except Exception as e:
                    # Best effort: a failed frame must not stop the line.
                    print(f"处理失败: {str(e)}")
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            print(f"检测完成 - 总数: {self.total_count}, 缺陷: {self.defect_count}")
        finally:
            # Release the camera and windows even on an unexpected error.
            self.camera.release()
            cv2.destroyAllWindows()

    @staticmethod
    def _format_log_row(timestamp, result):
        """Build one newline-terminated CSV row: timestamp, class, confidence."""
        return f"{timestamp},{result['class']},{result['confidence']}\n"

    def log_defect(self, frame, result):
        """Save the defective frame and append a row to ``defects/log.csv``."""
        import os

        os.makedirs("defects", exist_ok=True)
        # Microseconds keep two defects within the same second from
        # overwriting each other's image file.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"defects/defect_{timestamp}.jpg"
        cv2.imwrite(filename, frame)
        with open("defects/log.csv", "a") as f:
            f.write(self._format_log_row(timestamp, result))

    def display_result(self, frame, result):
        """Draw the prediction label and blend the heatmap over the frame."""
        label = f"Class: {result['class']}, Confidence: {result['confidence']:.2f}"
        cv2.putText(frame, label, (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        # Resize the heatmap to the frame size and colorize it.
        heatmap = np.array(result['heatmap'], dtype=np.float32)
        heatmap = cv2.resize(heatmap, (frame.shape[1], frame.shape[0]))
        heatmap = (np.clip(heatmap, 0.0, 1.0) * 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        blended = cv2.addWeighted(frame, 0.7, heatmap, 0.3, 0)
        cv2.imshow('Defect Detection', blended)