Python+PyTorch构建智能工业缺陷检测系统实战:从图像处理到模型部署 | 工业AI解决方案

2025-08-10 0 428

基于深度学习的工业质检全流程开发指南

一、工业缺陷检测技术选型

主流工业视觉检测方案对比:

| 技术方案 | 准确率 | 处理速度 | 硬件成本 |
|---|---|---|---|
| 传统图像处理 | 70-85% | 快(50-100ms) | (原文缺失) |
| 机器学习 | 85-92% | 中(100-300ms) | (原文缺失) |
| 深度学习 | 93-99% | 慢(300-1000ms) | (原文缺失) |
| 混合方案 | 90-97% | 中快(150-500ms) | 中高 |

二、系统架构设计

1. 智能检测系统架构

图像采集(工业相机) → 预处理(OpenCV处理) → 缺陷检测(PyTorch模型) → 分类定位(后处理分析) → 结果输出(MES系统对接)
            

2. 模型训练流程

数据收集(产线采集) → 标注清洗(Labelme工具) → 增强扩充(Albumentations) → 模型训练(迁移学习) → 量化压缩(TensorRT优化) → 部署上线(FastAPI服务)

三、核心模块实现

1. 工业图像预处理

import cv2
import numpy as np
import torch
from albumentations import (
    Compose, Rotate, RandomBrightnessContrast, 
    GaussNoise, OpticalDistortion
)

class ImagePreprocessor:
    """Turn industrial camera images into normalized tensors for the detector.

    One instance serves both training (random augmentation enabled) and
    inference (deterministic resize + ImageNet mean/std normalization).
    """

    # ImageNet statistics matching the pretrained backbone. Declared float32 so
    # the normalized image, and therefore the returned tensor, stays float32
    # like the model weights (a float64 tensor would fail in the forward pass).
    _MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    _STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, target_size=(512, 512)):
        # (width, height), the order cv2.resize expects.
        self.target_size = target_size
        # Applied to the resized uint8 image, i.e. to 0-255 pixel values.
        # NOTE(review): newer albumentations releases deprecate
        # GaussNoise(var_limit=...); confirm against the installed version.
        self.augmentations = Compose([
            Rotate(limit=15, p=0.5),
            RandomBrightnessContrast(p=0.3),
            GaussNoise(var_limit=(10, 50), p=0.2),
            OpticalDistortion(p=0.1)
        ])

    def process(self, image_path, is_training=False):
        """Load an image and convert it to a normalized CHW float32 tensor.

        Args:
            image_path: path to an image file, or an already-decoded image as a
                NumPy array in OpenCV channel order (H x W x 3 BGR, or H x W
                grayscale), e.g. a camera frame.
            is_training: when True, apply the random augmentation pipeline.

        Returns:
            torch.Tensor of dtype float32 and shape (3, H, W), where
            (W, H) == self.target_size, normalized with ImageNet mean/std.

        Raises:
            FileNotFoundError: if a path cannot be read as an image.
            ValueError: if an array input is not H x W or H x W x 3.
        """
        image = self._load_rgb(image_path)
        image = cv2.resize(image, self.target_size)

        if is_training:
            # Augment while still uint8 in 0-255: GaussNoise's var_limit is in
            # those units and would swamp an image already scaled to [0, 1].
            image = self.augmentations(image=image)['image']

        image = image.astype(np.float32) / 255.0
        image = (image - self._MEAN) / self._STD

        # HWC -> CHW; ascontiguousarray hands torch a dense buffer.
        return torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))

    @staticmethod
    def _load_rgb(source):
        """Return ``source`` (a path or a BGR/grayscale array) as an RGB array."""
        if isinstance(source, np.ndarray):
            bgr = source
        else:
            bgr = cv2.imread(str(source), cv2.IMREAD_COLOR)
            if bgr is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise FileNotFoundError(f"cannot read image: {source}")
        if bgr.ndim == 2:
            return cv2.cvtColor(bgr, cv2.COLOR_GRAY2RGB)
        if bgr.ndim == 3 and bgr.shape[2] == 3:
            return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        raise ValueError(f"expected an HxW or HxWx3 image, got shape {bgr.shape}")

    def adaptive_threshold(self, image):
        """Binarize an RGB image with Gaussian adaptive thresholding.

        Converts to grayscale, then thresholds each pixel against a Gaussian-
        weighted 11x11 neighbourhood mean minus 2. THRESH_BINARY_INV makes
        darker-than-neighbourhood pixels 255. Expects an 8-bit RGB image.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        return cv2.adaptiveThreshold(
            gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV, 11, 2
        )

    def remove_background(self, image):
        """Morphological clean-up: 5x5 opening (2 passes), then dilation (3 passes).

        The opening removes small specks; the dilation grows the remaining
        foreground regions.
        """
        kernel = np.ones((5, 5), np.uint8)
        opening = cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel, iterations=2)
        return cv2.dilate(opening, kernel, iterations=3)

2. 缺陷检测模型

import torch
import torch.nn as nn
from torchvision.models import resnet34

class DefectDetector(nn.Module):
    """ResNet34 defect classifier with a coarse localization (mask) head.

    forward(x) returns ``(cls_output, loc_output)``:
      * cls_output: (N, num_classes) raw logits.
      * loc_output: (N, 1, H/8, W/8) per-pixel defect probabilities in [0, 1].
        The layer4 map is at stride 32 and the locator upsamples it twice by 2x.
    """

    def __init__(self, num_classes=5, pretrained=True):
        super().__init__()
        # ResNet34 backbone; pretrained=True loads ImageNet weights.
        # Pass pretrained=False when the weights come from a checkpoint anyway.
        self.backbone = resnet34(pretrained=pretrained)

        # Drop the ImageNet classifier; the pooled 512-d features feed our head.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # Classification head.
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

        # Localization head, applied to the 512-channel layer4 feature map.
        self.locator = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def _extract_features(self, x):
        """Run the ResNet stem and all four stages; return the layer4 feature map."""
        b = self.backbone
        x = b.maxpool(b.relu(b.bn1(b.conv1(x))))
        x = b.layer1(x)
        x = b.layer2(x)
        x = b.layer3(x)
        return b.layer4(x)

    def forward(self, x):
        # Compute the layer4 map once and share it between both heads.
        # layer4 only accepts layer3's 256-channel output, not the raw image.
        feature_map = self._extract_features(x)

        # Classification: global average pooling, then the (identity) fc and our head.
        pooled = torch.flatten(self.backbone.avgpool(feature_map), 1)
        cls_output = self.classifier(self.backbone.fc(pooled))

        # Localization from the spatial feature map.
        loc_output = self.locator(feature_map)

        return cls_output, loc_output

class DefectLoss(nn.Module):
    """Weighted sum of a classification loss and a localization loss.

    total = alpha * CrossEntropy(cls_pred, cls_target)
            + (1 - alpha) * BCE(loc_pred, loc_target)

    ``loc_pred`` must already be probabilities (BCELoss does not apply a
    sigmoid), and ``loc_target`` a float tensor of the same shape.
    """

    def __init__(self, alpha=0.7):
        super().__init__()
        self.alpha = alpha  # weight of the classification term
        self.cls_loss = nn.CrossEntropyLoss()
        self.loc_loss = nn.BCELoss()

    def forward(self, outputs, targets):
        (cls_pred, loc_pred), (cls_target, loc_target) = outputs, targets
        weighted_cls = self.alpha * self.cls_loss(cls_pred, cls_target)
        weighted_loc = (1 - self.alpha) * self.loc_loss(loc_pred, loc_target)
        return weighted_cls + weighted_loc

四、高级功能实现

1. 小样本学习增强

import torch
from torch.utils.data import Dataset
from sklearn.utils.class_weight import compute_class_weight

class FewShotDataset(Dataset):
    """Image dataset over (path, label) pairs with precomputed class weights.

    Attributes:
        class_weights: sklearn ``compute_class_weight('balanced')`` output,
            ordered like ``np.unique(labels)`` (e.g. for a weighted loss).
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

        # 'balanced': weight inversely proportional to each class's frequency.
        unique_classes = np.unique(labels)
        self.class_weights = compute_class_weight(
            'balanced', classes=unique_classes, y=labels
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        picture = Image.open(self.image_paths[idx]).convert('RGB')
        target = self.labels[idx]
        return (picture if not self.transform else self.transform(picture)), target

class FewShotSampler:
    """Infinite episode sampler for N-way K-shot few-shot training.

    Each yielded episode is a flat list of Python-int dataset indices: the K
    support samples of every selected class first, then the Q query samples of
    every selected class, in the same class order.
    """

    def __init__(self, dataset, n_way=5, k_shot=3, query_num=5, rng=None):
        self.dataset = dataset
        self.n_way = n_way  # classes per episode
        self.k_shot = k_shot  # support samples per class
        self.query_num = query_num  # query samples per class
        # Random source: any object with a NumPy-style ``choice``, such as
        # np.random.default_rng(seed). Defaults to the global np.random state.
        self.rng = np.random if rng is None else rng

        # Group indices by class. Prefer a ``labels`` attribute (as on
        # FewShotDataset) so labels are read without decoding every image;
        # otherwise fall back to iterating (sample, label) pairs.
        labels = getattr(dataset, 'labels', None)
        if labels is None:
            labels = (label for _, label in dataset)
        self.class_indices = {}
        for idx, label in enumerate(labels):
            self.class_indices.setdefault(label, []).append(idx)

        self.classes = list(self.class_indices.keys())

    def __iter__(self):
        per_class = self.k_shot + self.query_num
        # Only classes holding at least K+Q samples can supply distinct indices.
        eligible = [c for c in self.classes
                    if len(self.class_indices[c]) >= per_class]
        if len(eligible) < self.n_way:
            raise ValueError(
                f"need {self.n_way} classes with >= {per_class} samples each, "
                f"found {len(eligible)}"
            )

        while True:
            # Sample positions rather than label values so any hashable label
            # type (strings, tuples, ...) works.
            picked = self.rng.choice(len(eligible), self.n_way, replace=False)

            support_set = []
            query_set = []
            for pos in picked:
                members = self.class_indices[eligible[pos]]
                drawn = self.rng.choice(len(members), per_class, replace=False)
                chosen = [int(members[i]) for i in drawn]
                support_set.extend(chosen[:self.k_shot])  # first K -> support
                query_set.extend(chosen[self.k_shot:])  # remaining Q -> query

            yield support_set + query_set

2. 模型解释性可视化

import matplotlib.pyplot as plt
from captum.attr import IntegratedGradients
from captum.attr import visualization as viz

class ModelExplainer:
    """Attribution visualizations (Integrated Gradients, gradient saliency).

    Models that return a tuple or list, such as a detector returning
    ``(cls_output, loc_output)``, are supported: only the first element (the
    classification scores) is explained.
    """

    def __init__(self, model):
        self.model = model
        # Captum needs a forward function returning (N, num_classes) scores,
        # so multi-head outputs are unwrapped to the classification logits.
        self.ig = IntegratedGradients(self._class_scores)

    def _class_scores(self, inputs):
        """Return the model's classification scores for ``inputs``."""
        output = self.model(inputs)
        if isinstance(output, (tuple, list)):
            output = output[0]
        return output

    def explain(self, input_tensor, target_class, n_steps=50):
        """Plot the input next to an Integrated Gradients heat map.

        Args:
            input_tensor: model input; presumably (1, 3, H, W), since it is
                squeezed to CHW for display (TODO confirm with callers).
            target_class: class index to attribute.
            n_steps: number of integration steps.

        Returns:
            The matplotlib Figure.
        """
        attributions = self.ig.attribute(
            input_tensor,
            target=target_class,
            n_steps=n_steps
        )

        # CHW -> HWC NumPy arrays, the layout both imshow and Captum's
        # visualize_image_attr expect.
        img = input_tensor.squeeze().permute(1, 2, 0).cpu().detach().numpy()
        attr = attributions.squeeze().permute(1, 2, 0).cpu().detach().numpy()

        # Inputs are usually mean/std-normalized; rescale to [0, 1] for display
        # so imshow does not clip most of the image.
        span = img.max() - img.min()
        display = (img - img.min()) / span if span > 0 else np.zeros_like(img)

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))

        ax[0].imshow(display)
        ax[0].set_title('Original Image')
        ax[0].axis('off')

        viz.visualize_image_attr(
            attr,
            original_image=display,
            method='blended_heat_map',
            sign='absolute_value',
            show_colorbar=True,
            title='Feature Importance',
            plt_fig_axis=(fig, ax[1])
        )

        return fig

    def generate_saliency_map(self, input_tensor, target_class):
        """Plot and return a vanilla-gradient saliency map.

        Args:
            input_tensor: a single unbatched image of shape (C, H, W); it is
                unsqueezed to a batch of one before the forward pass.
            target_class: class index whose score is differentiated.

        Returns:
            (H, W) tensor: max over channels of |d score / d input|.
        """
        # Differentiate w.r.t. a detached leaf copy. requires_grad_() is then
        # legal even for non-leaf inputs, and nothing accumulates in
        # input_tensor.grad or in the model parameters' .grad.
        x = input_tensor.detach().clone().requires_grad_(True)
        scores = self._class_scores(x.unsqueeze(0))
        (grad,) = torch.autograd.grad(scores[0, target_class], x)

        saliency = grad.abs().max(dim=0)[0]

        plt.figure(figsize=(6, 6))
        plt.imshow(saliency.cpu(), cmap='hot')
        plt.colorbar()
        plt.title('Saliency Map')
        plt.axis('off')

        return saliency

五、生产部署优化

1. 模型量化与加速

import torch
from torch.quantization import quantize_dynamic
from torch.utils.mobile_optimizer import optimize_for_mobile

class ModelOptimizer:
    """Compression and export helpers for a trained PyTorch model.

    The wrapped model is put into eval mode on construction. The quantization
    methods work on copies and return new models; the export methods write
    files to the given paths.
    """

    def __init__(self, model):
        self.model = model.eval()

    def dynamic_quantization(self):
        """Return a copy with ``nn.Linear`` weights dynamically quantized to int8.

        PyTorch's dynamic quantization only has kernels for Linear and
        recurrent layers; convolutions would be left in float32 anyway.
        """
        quantized_model = quantize_dynamic(
            self.model,
            {torch.nn.Linear},
            dtype=torch.qint8
        )
        return quantized_model

    def static_quantization(self, calibration_data, backend=None):
        """Post-training static int8 quantization (FX graph mode).

        Args:
            calibration_data: non-empty iterable of model inputs (e.g.
                (N, 3, H, W) tensors). Observers record activation ranges from them.
            backend: backend for the default qconfig mapping ('x86', 'fbgemm',
                'qnnpack', ...). Defaults to torch.backends.quantized.engine.

        Returns:
            A converted GraphModule with quantized weights and activations.
            Its inputs and outputs stay float.

        Raises:
            ValueError: if calibration_data is empty.
        """
        import copy

        from torch.ao.quantization import get_default_qconfig_mapping
        from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

        batches = iter(calibration_data)
        try:
            first = next(batches)
        except StopIteration:
            raise ValueError("calibration_data must contain at least one batch") from None

        backend = backend or torch.backends.quantized.engine
        float_model = copy.deepcopy(self.model).eval()
        # FX mode inserts the quant/dequant boundaries automatically, so the
        # model does not need hand-placed QuantStub/DeQuantStub modules.
        prepared = prepare_fx(
            float_model,
            get_default_qconfig_mapping(backend),
            example_inputs=(first,)
        )

        # Calibration: run observers over the data; no gradients needed.
        with torch.no_grad():
            prepared(first)
            for data in batches:
                prepared(data)

        return convert_fx(prepared)

    def optimize_for_mobile(self, output_path):
        """Script the model, apply mobile optimizations and save for the lite interpreter."""
        scripted_model = torch.jit.script(self.model)
        # Resolves to the module-level torch.utils.mobile_optimizer function:
        # class attributes are not in scope inside method bodies.
        optimized_model = optimize_for_mobile(scripted_model)
        optimized_model._save_for_lite_interpreter(output_path)
        return optimized_model

    def convert_to_onnx(self, output_path, input_shape=(1, 3, 512, 512),
                        opset_version=11, output_names=None):
        """Export the model to ONNX with a dynamic batch dimension.

        Args:
            output_path: destination .onnx file.
            input_shape: shape of the dummy input used for tracing.
            opset_version: ONNX opset to target.
            output_names: names for the model outputs. Default: 'output' for
                the first output and 'output_1', 'output_2', ... for any
                further ones (e.g. the localization map of a two-head model).
        """
        # Create the dummy input on the model's device so export also works
        # for a model that lives on the GPU.
        device = next(self.model.parameters(), torch.empty(0)).device
        dummy_input = torch.randn(*input_shape, device=device)

        if output_names is None:
            with torch.no_grad():
                sample = self.model(dummy_input)
            n_outputs = len(sample) if isinstance(sample, (tuple, list)) else 1
            output_names = ['output'] + [f'output_{i}' for i in range(1, n_outputs)]

        # Every input and output gets a dynamic batch axis, not just the first output.
        dynamic_axes = {'input': {0: 'batch_size'}}
        dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})

        torch.onnx.export(
            self.model,
            dummy_input,
            output_path,
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=['input'],
            output_names=list(output_names),
            dynamic_axes=dynamic_axes
        )

    def tensorrt_optimization(self, onnx_path, output_path, max_batch_size=1):
        """Build a serialized TensorRT engine from an ONNX file.

        Inputs whose batch dimension is dynamic (-1, as produced by
        convert_to_onnx) get an optimization profile covering batch sizes
        1..max_batch_size. TensorRT refuses to build dynamic-shape networks
        without such a profile.

        Raises:
            ValueError: if parsing fails, or a non-batch input dim is dynamic.
            RuntimeError: if TensorRT fails to build the engine.
        """
        import tensorrt as trt

        logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(logger)
        network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        parser = trt.OnnxParser(network, logger)

        with open(onnx_path, 'rb') as f:
            if not parser.parse(f.read()):
                errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
                raise ValueError("ONNX解析失败: " + "; ".join(errors))

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB

        profile = None
        for i in range(network.num_inputs):
            tensor = network.get_input(i)
            shape = tuple(tensor.shape)
            if -1 not in shape:
                continue
            rest = shape[1:]
            if -1 in rest:
                raise ValueError(
                    f"input {tensor.name!r} has dynamic non-batch dims {shape}; "
                    "export with a fixed spatial size"
                )
            if profile is None:
                profile = builder.create_optimization_profile()
            profile.set_shape(tensor.name, (1, *rest), (1, *rest), (max_batch_size, *rest))
        if profile is not None:
            config.add_optimization_profile(profile)

        serialized_engine = builder.build_serialized_network(network, config)
        # build_serialized_network returns None on failure instead of raising.
        if serialized_engine is None:
            raise RuntimeError("TensorRT engine build failed; see the TensorRT log output")
        with open(output_path, 'wb') as f:
            f.write(serialized_engine)

2. 高性能推理服务

from fastapi import FastAPI, File, UploadFile
import uvicorn
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
import io

app = FastAPI()
# Detector shared by all requests; populated by load_model() at startup.
model = None
# Thread pool used by the request path to run blocking work off the event loop.
executor = ThreadPoolExecutor(max_workers=4)


def load_model(weights_path='best_model.pth', num_classes=None):
    """Load the trained DefectDetector into the module-level ``model``.

    Args:
        weights_path: state_dict checkpoint to load.
        num_classes: number of output classes. When None it is read from the
            checkpoint's final classifier layer, so a 6-class checkpoint loads
            without a size-mismatch error.
    """
    global model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE: torch.load unpickles its input; only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location=device)

    if num_classes is None:
        # classifier = Sequential(Linear, ReLU, Dropout, Linear): index 3 is the output layer.
        head = state_dict.get('classifier.3.weight')
        net = DefectDetector() if head is None else DefectDetector(num_classes=head.shape[0])
    else:
        net = DefectDetector(num_classes=num_classes)

    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    # Publish only a fully loaded model.
    model = net


@app.on_event("startup")
async def startup_event():
    """Load the model once when the ASGI application starts."""
    load_model()

async def process_image(image_bytes):
    """Preprocess and classify one uploaded image without blocking the event loop.

    Both preprocessing and the forward pass are blocking calls. Each runs on
    ``executor`` so other requests keep being served meanwhile.

    Returns:
        dict with ``class`` (predicted index), ``confidence`` (its softmax
        probability) and ``heatmap`` (localization output as nested lists).

    Raises:
        RuntimeError: if the model has not been loaded yet.
    """
    import asyncio

    if model is None:
        raise RuntimeError("model is not loaded")

    loop = asyncio.get_running_loop()
    image = await loop.run_in_executor(executor, preprocess_image, image_bytes)
    return await loop.run_in_executor(executor, _run_inference, image)


def _run_inference(image):
    """Run the global model on one (3, H, W) tensor and build the response dict."""
    device = next(model.parameters()).device
    with torch.no_grad():
        cls_output, loc_output = model(image.unsqueeze(0).to(device))
        probabilities = torch.softmax(cls_output, dim=1)
    confidence, predicted = probabilities.max(dim=1)

    return {
        'class': predicted.item(),
        'confidence': confidence.item(),
        'heatmap': loc_output.squeeze().cpu().numpy().tolist()
    }

def preprocess_image(image_bytes, target_size=(512, 512)):
    """Decode uploaded image bytes into the detector's input tensor.

    Applies the inference-time pipeline entirely in memory: convert to RGB,
    cv2.resize to ``target_size`` (width, height), scale to [0, 1],
    normalize with ImageNet mean/std, then HWC -> CHW. It neither writes a
    temp file nor builds an augmentation pipeline per request.

    Returns:
        float32 torch.Tensor of shape (3, target_size[1], target_size[0]).

    Raises:
        PIL.UnidentifiedImageError (an OSError) if the bytes are not an image.
    """
    rgb = np.asarray(Image.open(io.BytesIO(image_bytes)).convert('RGB'))
    rgb = cv2.resize(rgb, target_size)

    scaled = rgb.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    normalized = (scaled - mean) / std

    return torch.from_numpy(np.ascontiguousarray(normalized.transpose(2, 0, 1)))

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Success: HTTP 200 with ``{"status": "success", "result": {...}}``.
    Failure: the same ``{"status": "error", "message": ...}`` body as before,
    but with HTTP 400 (empty upload) or 500 (processing error). Clients that
    only check the status code no longer mistake failures for results.
    """
    import logging

    from fastapi.responses import JSONResponse

    try:
        image_bytes = await file.read()
        if not image_bytes:
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "empty upload"}
            )
        result = await process_image(image_bytes)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log.
        logging.getLogger(__name__).exception("prediction failed")
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": str(e)}
        )

    return {
        "status": "success",
        "result": result
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

六、实战案例:PCB板缺陷检测

1. 数据集构建与训练

import glob

import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize, Resize, ToTensor


def _match_mask(mask, reference):
    """Return ``mask`` as a float (N, 1, h, w) tensor sized like ``reference``.

    Adds a channel axis to (N, H, W) masks and resizes with nearest-neighbour
    (keeping masks binary) when the resolution differs from the locator output.
    """
    mask = mask.float()
    if mask.dim() == 3:
        mask = mask.unsqueeze(1)
    if mask.shape[-2:] != reference.shape[-2:]:
        mask = F.interpolate(mask, size=reference.shape[-2:], mode='nearest')
    return mask


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """One pass over ``loader``: trains when ``optimizer`` is given, else evaluates.

    Returns:
        (mean loss per batch, classification accuracy in percent)
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for images, (cls_targets, loc_targets) in loader:
            images = images.to(device)
            cls_targets = cls_targets.to(device)

            cls_output, loc_output = model(images)
            loc_targets = _match_mask(loc_targets.to(device), loc_output)
            loss = criterion((cls_output, loc_output), (cls_targets, loc_targets))

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            correct += (cls_output.argmax(dim=1) == cls_targets).sum().item()
            seen += cls_targets.size(0)

    mean_loss = total_loss / max(len(loader), 1)
    accuracy = 100.0 * correct / seen if seen else 0.0
    return mean_loss, accuracy


def train_pcb_detector(image_glob='dataset/pcb_images/*.jpg', num_classes=6,
                       epochs=50, image_size=(512, 512),
                       checkpoint_path='best_model.pth', seed=None):
    """Train DefectDetector on the PCB images and save the best checkpoint.

    Depends on two helpers this snippet does not define:
    ``get_label_from_path`` (image path -> class label) and ``DefectDataset``.
    DefectDataset presumably yields ``(image, (cls_target, mask))`` per sample,
    as the loop unpacks (TODO confirm).

    Args:
        image_glob: glob pattern for the training images.
        num_classes: number of defect classes (including "no defect").
        epochs: training epochs.
        image_size: (height, width) fed to the network; it should match the
            inference service's resize.
        checkpoint_path: where the best (by validation accuracy) state_dict is saved.
        seed: random_state for the train/val split; None = non-deterministic.

    Returns:
        The trained model (weights of the last epoch).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1. Dataset
    image_paths = sorted(glob.glob(image_glob))
    if not image_paths:
        raise FileNotFoundError(f"no images match {image_glob!r}")
    labels = [get_label_from_path(p) for p in image_paths]  # user-supplied helper

    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths, labels, test_size=0.2, stratify=labels, random_state=seed
    )

    # 2. Data loaders. Use the same resize + ImageNet normalization as the
    # inference path, so training and production inputs share one distribution.
    transform = Compose([
        ToTensor(),
        Resize(image_size),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = DefectDataset(train_paths, train_labels, transform=transform)
    val_dataset = DefectDataset(val_paths, val_labels, transform=transform)

    train_loader = DataLoader(
        train_dataset, batch_size=32, shuffle=True, num_workers=4
    )
    val_loader = DataLoader(
        val_dataset, batch_size=16, shuffle=False, num_workers=2
    )

    # 3. Model and optimizer
    model = DefectDetector(num_classes=num_classes).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    criterion = DefectLoss(alpha=0.7)

    # 4. Training loop; keep the best validation checkpoint for the service.
    best_acc = -1.0
    for epoch in range(epochs):
        train_loss, _ = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        print(f'Epoch {epoch+1}: '
              f'Train Loss: {train_loss:.4f}, '
              f'Val Loss: {val_loss:.4f}, '
              f'Val Acc: {val_acc:.2f}%')

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

    return model


# Guarded so DataLoader worker processes (spawn start method) do not re-run training.
if __name__ == "__main__":
    train_pcb_detector()

2. 产线集成方案

import cv2
import requests
from datetime import datetime

class ProductionLineIntegration:
    """Camera loop that classifies frames via the /predict HTTP service.

    Defective frames (class != 0) are saved under ``defect_dir`` and appended
    to ``defect_dir/log.csv``. Every processed frame is shown with its
    prediction and localization heat map. Press 'q' to stop.
    """

    def __init__(self, camera_index=0, api_endpoint="http://localhost:8000/predict",
                 timeout=5.0, defect_dir="defects"):
        self.camera = cv2.VideoCapture(camera_index)
        self.api_endpoint = api_endpoint
        self.timeout = timeout  # seconds to wait for the service per frame
        self.defect_dir = defect_dir
        self.defect_count = 0
        self.total_count = 0

    def process_frame(self, frame):
        """Send one BGR frame to the inference API and return its ``result`` dict.

        Raises:
            RuntimeError: if JPEG encoding fails or the service reports an error.
            requests.RequestException: on connection errors or timeouts.
        """
        ok, img_encoded = cv2.imencode('.jpg', frame)
        if not ok:
            raise RuntimeError("failed to JPEG-encode frame")

        files = {'file': ('frame.jpg', img_encoded.tobytes(), 'image/jpeg')}
        # A timeout keeps a stalled service from freezing the production line.
        response = requests.post(self.api_endpoint, files=files, timeout=self.timeout)

        if response.status_code != 200:
            raise RuntimeError(f"API调用失败: {response.text}")
        payload = response.json()
        # The service may report failures in the body; check before indexing.
        if payload.get('status') != 'success':
            raise RuntimeError(f"API调用失败: {payload.get('message', response.text)}")
        return payload['result']

    def run(self):
        """Main loop: read, classify, log and display frames until EOF or 'q'."""
        try:
            while True:
                ret, frame = self.camera.read()
                if not ret:
                    break

                try:
                    result = self.process_frame(frame)
                    self.total_count += 1

                    if result['class'] != 0:  # 0 means "no defect"
                        self.defect_count += 1
                        # Log before display_result draws text onto the frame.
                        self.log_defect(frame, result)

                    self.display_result(frame, result)

                except Exception as e:
                    # Best effort: one bad frame must not stop the line.
                    print(f"处理失败: {str(e)}")

                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Release the camera and windows even on errors or Ctrl+C.
            print(f"检测完成 - 总数: {self.total_count}, 缺陷: {self.defect_count}")
            self.camera.release()
            cv2.destroyAllWindows()

    def log_defect(self, frame, result):
        """Save the defective frame and append ``timestamp,class,confidence`` to log.csv."""
        import os

        os.makedirs(self.defect_dir, exist_ok=True)
        # Microseconds keep several defects within one second from overwriting
        # each other. The same stamp goes into the CSV so rows match images.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = os.path.join(self.defect_dir, f"defect_{timestamp}.jpg")
        if not cv2.imwrite(filename, frame):  # imwrite reports failure via its return value
            raise RuntimeError(f"failed to write {filename}")

        with open(os.path.join(self.defect_dir, "log.csv"), "a", encoding="utf-8") as f:
            f.write(f"{timestamp},{result['class']},{result['confidence']}\n")

    def display_result(self, frame, result):
        """Overlay the prediction text and heat map on ``frame`` and show it."""
        label = f"Class: {result['class']}, Confidence: {result['confidence']:.2f}"
        cv2.putText(frame, label, (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        # The heat map arrives as nested lists of probabilities; upscale it to
        # the frame size and colour-map it. Clip so out-of-range values cannot
        # wrap around when cast to uint8.
        heatmap = np.asarray(result['heatmap'], dtype=np.float32)
        heatmap = cv2.resize(heatmap, (frame.shape[1], frame.shape[0]))
        heatmap = (np.clip(heatmap, 0.0, 1.0) * 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

        blended = cv2.addWeighted(frame, 0.7, heatmap, 0.3, 0)
        cv2.imshow('Defect Detection', blended)
Python+PyTorch构建智能工业缺陷检测系统实战:从图像处理到模型部署 | 工业AI解决方案
收藏 (0) 打赏

感谢您的支持,我会继续努力的!

打开微信/支付宝扫一扫,即可进行扫码打赏哦,分享从这里开始,精彩与您同在
点赞 (0)

淘吗网 python Python+PyTorch构建智能工业缺陷检测系统实战:从图像处理到模型部署 | 工业AI解决方案 https://www.taomawang.com/server/python/795.html

常见问题

相关文章

发表评论
暂无评论
官方客服团队

为您解决烦忧 - 24小时在线 专业服务