PyTorch 深度学习实战(32):多模态学习与CLIP模型
zhezhongyun 2025-04-27 17:32 41 浏览
在上一篇文章中,我们探讨了可解释性AI与特征可视化技术。本文将深入多模态学习领域,重点介绍OpenAI提出的CLIP(Contrastive Language-Image Pretraining)模型,该模型通过对比学习实现了图像与文本的联合理解。
一、多模态学习基础
1. 核心概念
- 模态对齐:建立不同模态(如图像/文本)间的语义关联
- 跨模态检索:实现图文双向搜索
- 联合表征:学习统一的特征空间
2. 技术对比
方法 | 代表模型 | 特点 | 典型应用 |
双塔结构 | CLIP | 对比学习预训练 | 零样本分类 |
融合编码器 | ViLBERT | 跨模态注意力机制 | 视觉问答 |
生成式架构 | DALL·E | 文本到图像生成 | 创意内容生成 |
统一Transformer | Flamingo | 处理交错图文序列 | 多模态对话 |
二、CLIP模型原理
1. 对比学习目标
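CLIP通过优化图像-文本对的相似度矩阵来训练:对一个批次内的 N 个图文对,令 $s_{ij} = e^{\tau}\,\dfrac{I_i \cdot T_j}{\lVert I_i\rVert\,\lVert T_j\rVert}$(其中 $\tau$ 为可学习的 logit scale),最小化对称交叉熵损失:
$$\mathcal{L} = -\frac{1}{2N}\sum_{i=1}^{N}\left[\log\frac{\exp(s_{ii})}{\sum_{j=1}^{N}\exp(s_{ij})} + \log\frac{\exp(s_{ii})}{\sum_{j=1}^{N}\exp(s_{ji})}\right]$$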
2. 模型架构
import torch
from torch import nn
from typing import Tuple, Optional
import torch.nn.functional as F
class CLIP(nn.Module):
    """Dual-tower CLIP model trained with a symmetric contrastive loss.

    Each encoder is followed by a linear projection (plus dropout) into a
    shared embedding space. Embeddings are L2-normalised, so the scaled dot
    product between them is a scaled cosine similarity.
    """

    # Upper bound applied to ``exp(logit_scale)`` so the logits stay finite.
    MAX_LOGIT_SCALE = 100.0

    def __init__(
        self,
        image_encoder: nn.Module,
        text_encoder: nn.Module,
        embed_dim: int = 512,
        init_logit_scale: float = 2.6592,
        projection_dropout: float = 0.1
    ):
        """Build the projection heads and the learnable temperature.

        Args:
            image_encoder: Image backbone; must expose an ``output_dim`` attribute.
            text_encoder: Text backbone; must expose an ``output_dim`` attribute.
            embed_dim: Dimension of the joint embedding space.
            init_logit_scale: Initial value of the log-space temperature
                (it is exponentiated before use).
            projection_dropout: Dropout rate applied after each projection.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        # Projection heads mapping each encoder's output into the joint space.
        self.image_proj = nn.Sequential(
            nn.Linear(image_encoder.output_dim, embed_dim),
            nn.Dropout(projection_dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_encoder.output_dim, embed_dim),
            nn.Dropout(projection_dropout)
        )
        # Learnable temperature, stored in log space.
        self.logit_scale = nn.Parameter(torch.tensor([init_logit_scale]))
        self._init_weights()

    def _init_weights(self):
        """Initialise the projection layers (normal weights, zero bias)."""
        for proj in (self.image_proj, self.text_proj):
            linear = proj[0]
            if isinstance(linear, nn.Linear):
                nn.init.normal_(linear.weight, std=0.02)
                if linear.bias is not None:
                    nn.init.zeros_(linear.bias)

    def _similarity_logits(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return the temperature-scaled image-by-text similarity matrix."""
        scale = self.logit_scale.exp().clamp(max=self.MAX_LOGIT_SCALE)
        return scale * image_features @ text_features.t()

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        """Return L2-normalised image embeddings."""
        # F.normalize guards the division with an epsilon, so an all-zero
        # projection yields zeros instead of NaN.
        return F.normalize(self.image_proj(self.image_encoder(image)), dim=1)

    def encode_text(self, text: torch.Tensor) -> torch.Tensor:
        """Return L2-normalised text embeddings."""
        return F.normalize(self.text_proj(self.text_encoder(text)), dim=1)

    def forward(
        self,
        image: torch.Tensor,
        text: torch.Tensor,
        return_features: bool = False
    ) -> "torch.Tensor | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
        """Compute the image-text similarity logits for a batch.

        Args:
            image: Batch of inputs for the image encoder.
            text: Batch of inputs for the text encoder.
            return_features: When true, also return both embedding tensors.

        Returns:
            ``logits`` alone, or ``(logits, image_features, text_features)``
            when ``return_features`` is true. ``logits[i, j]`` scores image
            ``i`` against text ``j``.
        """
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        logits = self._similarity_logits(image_features, text_features)
        if return_features:
            return logits, image_features, text_features
        return logits

    def compute_loss(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """Compute the symmetric contrastive (cross-entropy) loss.

        Args:
            image_features: Normalised image embeddings, one row per sample.
            text_features: Normalised text embeddings, row ``i`` paired with
                image row ``i``.

        Returns:
            Scalar tensor: mean of the image-to-text and text-to-image losses.
        """
        logits_per_image = self._similarity_logits(image_features, text_features)
        logits_per_text = logits_per_image.t()
        # Matching pairs sit on the diagonal, so the target of row i is i.
        labels = torch.arange(image_features.shape[0], device=image_features.device)
        loss_image = F.cross_entropy(logits_per_image, labels)
        loss_text = F.cross_entropy(logits_per_text, labels)
        return (loss_image + loss_text) / 2
# 示例用法
if __name__ == "__main__":
    class MockEncoder(nn.Module):
        """Stand-in encoder: projects random noise, ignoring the input content.

        Only the batch size and device of the input are used. A real setup
        would plug in a ViT / Transformer backbone instead.
        """

        def __init__(self, output_dim=768):
            super().__init__()
            self.output_dim = output_dim
            self.proj = nn.Linear(1000, output_dim)

        def forward(self, x):
            noise = torch.randn(x.shape[0], 1000).to(x.device)
            return self.proj(noise)

    # Assemble the model from two independent mock towers.
    image_encoder = MockEncoder()
    text_encoder = MockEncoder()
    clip_model = CLIP(image_encoder, text_encoder)

    # Fake batch: 4 RGB images and 4 token-id sequences of length 77.
    batch_size = 4
    fake_images = torch.randn(batch_size, 3, 224, 224)
    fake_texts = torch.randint(0, 10000, (batch_size, 77))

    # Forward pass, asking for the embeddings as well as the logits.
    logits, img_feats, txt_feats = clip_model(
        fake_images,
        fake_texts,
        return_features=True,
    )
    print(f"相似度矩阵形状: {logits.shape}")
    print(f"图像特征形状: {img_feats.shape}")
    print(f"文本特征形状: {txt_feats.shape}")

    # Contrastive loss on the embeddings produced above.
    loss = clip_model.compute_loss(img_feats, txt_feats)
    print(f"对比损失值: {loss.item():.4f}")
    # Output:
相似度矩阵形状: torch.Size([4, 4])
图像特征形状: torch.Size([4, 512])
文本特征形状: torch.Size([4, 512])
对比损失值: 1.6367
三、CLIP实战应用
1. 使用官方预训练模型
import clip
import torch
from PIL import Image
# 加载模型与预处理
# Pick the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# One preprocessed image (batch of 1) against three candidate captions.
image = preprocess(Image.open("cat.jpeg")).unsqueeze(0).to(device)
candidate_captions = ["a cat", "a dog", "a bird"]
text = clip.tokenize(candidate_captions).to(device)

with torch.no_grad():
    # The model returns (logits_per_image, logits_per_text); only the first is used.
    logits_per_image, _ = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("预测概率:", probs)  # 预测概率: [[0.9785 0.01087 0.010704]]
2. 零样本图像分类
import torch
import clip
from PIL import Image
import matplotlib.pyplot as plt
from typing import List, Optional, Tuple
class ZeroShotCLIPClassifier:
    """Zero-shot image classifier on top of a pretrained CLIP model.

    Classes are described by free-form text prompts; the prediction is the
    prompt whose embedding has the highest cosine similarity to the image.
    """

    def __init__(self,
                 model_name: str = "ViT-B/32",
                 device: Optional[str] = None):
        """Load the CLIP model and its preprocessing pipeline.

        Args:
            model_name: CLIP model name (e.g. "ViT-B/32", "RN50").
            device: Target device; chosen automatically when ``None``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.preprocess = clip.load(model_name, device=self.device)
        self.model.eval()

    @staticmethod
    def _l2_normalize(features: torch.Tensor) -> torch.Tensor:
        """Scale each row to unit length (epsilon-guarded against zero rows)."""
        return features / features.norm(dim=-1, keepdim=True).clamp_min(1e-12)

    def predict(
        self,
        image_path: str,
        class_descriptions: List[str],
        temperature: float = 100.0,
        show_visualization: bool = True
    ) -> Tuple[str, torch.Tensor]:
        """Classify one image against a list of text descriptions.

        Args:
            image_path: Path of the image file.
            class_descriptions: One text prompt per candidate class.
            temperature: Multiplier applied to the cosine similarities before
                the softmax; larger values give a sharper distribution.
            show_visualization: Whether to plot the image and probabilities.

        Returns:
            ``(predicted description, probabilities)`` where the probabilities
            form a 1-D tensor with one entry per description.

        Raises:
            RuntimeError: If any step fails (the original error is chained).
        """
        try:
            image = self._load_and_preprocess(image_path)
            text_inputs = self._prepare_text(class_descriptions)
            with torch.no_grad():
                # Cast to float32 so the softmax is not done in half precision.
                image_features = self.model.encode_image(image).float()
                text_features = self.model.encode_text(text_inputs).float()
                # The encoders' raw outputs are not unit-length; without this
                # step the temperature would scale raw dot products and the
                # softmax would saturate at 0/1.
                image_features = self._l2_normalize(image_features)
                text_features = self._l2_normalize(text_features)
                logits = temperature * image_features @ text_features.T
                # squeeze(0) drops only the batch axis, so a single-class call
                # still returns a 1-D tensor.
                probs = logits.softmax(dim=-1).squeeze(0)
            pred_idx = probs.argmax().item()
            pred_class = class_descriptions[pred_idx]
            if show_visualization:
                self._visualize_results(image_path, class_descriptions, probs.cpu())
            return pred_class, probs
        except Exception as e:
            raise RuntimeError(f"分类失败: {str(e)}") from e

    def _load_and_preprocess(self, image_path: str) -> torch.Tensor:
        """Open the image, preprocess it and return a batch of one on ``self.device``.

        Raises:
            ValueError: If the file does not exist.
            RuntimeError: If loading or preprocessing fails for another reason.
        """
        try:
            # The context manager closes the file once preprocessing has read it.
            with Image.open(image_path) as image:
                return self.preprocess(image).unsqueeze(0).to(self.device)
        except FileNotFoundError as e:
            raise ValueError(f"图像文件不存在: {image_path}") from e
        except Exception as e:
            raise RuntimeError(f"图像加载失败: {str(e)}") from e

    def _prepare_text(self, descriptions: List[str]) -> torch.Tensor:
        """Tokenize the descriptions into one stacked tensor on ``self.device``.

        Raises:
            ValueError: If ``descriptions`` is empty.
        """
        if not descriptions:
            raise ValueError("类别描述列表不能为空")
        return torch.cat([clip.tokenize(desc) for desc in descriptions]).to(self.device)

    def _visualize_results(
        self,
        image_path: str,
        classes: List[str],
        probs: torch.Tensor
    ) -> None:
        """Show the input image next to a horizontal bar chart of probabilities."""
        values = probs.detach().cpu().float().numpy()
        plt.figure(figsize=(12, 6))
        # Left panel: the input image.
        plt.subplot(1, 2, 1)
        with Image.open(image_path) as image:
            plt.imshow(image)
        plt.axis('off')
        plt.title('Input Image')
        # Right panel: per-class probabilities, coloured relative to the maximum.
        plt.subplot(1, 2, 2)
        colors = plt.cm.viridis(values / values.max())
        bars = plt.barh(classes, values, color=colors)
        plt.xlabel('Probability')
        plt.title('Classification Probabilities')
        plt.gca().invert_yaxis()  # first class at the top
        # Numeric label at the end of each bar.
        for bar in bars:
            width = bar.get_width()
            plt.text(width + 0.01, bar.get_y() + bar.get_height() / 2,
                     f'{width:.2f}',
                     va='center')
        plt.tight_layout()
        plt.show()
# 使用示例
if __name__ == "__main__":
    # Candidate classes, phrased as prompts (extend freely).
    animal_classes = [
        f"a photo of a {animal}"
        for animal in ("cat", "dog", "bird", "horse", "fish")
    ]

    classifier = ZeroShotCLIPClassifier(model_name="ViT-B/32")

    # Replace with the path of your own image.
    image_path = "cat.jpeg"
    pred_class, probs = classifier.predict(
        image_path=image_path,
        class_descriptions=animal_classes,
        temperature=100.0,
        show_visualization=True,
    )

    print(f"\n预测结果: {pred_class}")
    print("各类别概率:")
    for idx, cls in enumerate(animal_classes):
        prob = probs[idx]
        print(f"- {cls}: {prob.item():.4f}")
    # Output:
预测结果: a photo of a cat
各类别概率:
- a photo of a cat: 1.0000
- a photo of a dog: 0.0000
- a photo of a bird: 0.0000
- a photo of a horse: 0.0000
- a photo of a fish: 0.0000
3. 特征空间可视化
import torch
import umap
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from typing import List, Optional, Tuple
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from sklearn.preprocessing import StandardScaler
class MultimodalVisualizer:
    """Project CLIP image and text embeddings to 2-D with UMAP and plot them."""

    def __init__(self,
                 model,
                 preprocess,
                 device: str = "cuda" if torch.cuda.is_available() else "cpu",
                 n_neighbors: int = 15,
                 min_dist: float = 0.1,
                 metric: str = 'cosine',
                 random_state: int = 42):
        """Store the model and configure the UMAP reducer and the scaler.

        Args:
            model: A loaded CLIP model.
            preprocess: CLIP image preprocessing function.
            device: Compute device; the model is moved onto it.
            n_neighbors: UMAP neighbourhood size.
            min_dist: UMAP minimum distance between embedded points.
            metric: UMAP distance metric.
            random_state: Random seed passed to UMAP.
        """
        self.model = model
        self.preprocess = preprocess
        self.device = device
        self.model.to(self.device)
        self.reducer = umap.UMAP(
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            metric=metric,
            random_state=random_state
        )
        self.scaler = StandardScaler()

    def visualize_embeddings(self, image_paths: List[str], texts: List[str], **kwargs):
        """Embed images and texts, reduce them to 2-D and return the figure.

        Images that cannot be loaded are skipped; only the images that were
        actually embedded are plotted, so points and thumbnails stay aligned.

        Args:
            image_paths: Image files to embed.
            texts: Text snippets to embed.
            **kwargs: Plot options (``figsize``, ``point_size``, ``sample_images``).

        Returns:
            The matplotlib figure.
        """
        image_embeddings, loaded_paths = self._encode_images(image_paths)
        text_embeddings = self._encode_texts(texts)
        all_embeddings = torch.cat([image_embeddings, text_embeddings]).cpu().numpy()
        scaled_embeddings = self.scaler.fit_transform(all_embeddings)
        # Reduce to 2-D; plotting raw columns 0 and 1 of the high-dimensional
        # embedding would not reflect neighbourhood structure.
        reduced_embeddings = self.reducer.fit_transform(scaled_embeddings)
        return self._plot_embeddings(
            reduced_embeddings,
            len(loaded_paths),
            loaded_paths,
            texts,
            **kwargs
        )

    def _encode_images(self, image_paths) -> Tuple[torch.Tensor, List[str]]:
        """Embed every loadable image; return the features and the paths kept."""
        features, loaded_paths = [], []
        for path in image_paths:
            try:
                with Image.open(path) as image:
                    image_input = self.preprocess(image).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    features.append(self.model.encode_image(image_input))
                loaded_paths.append(path)
            except Exception as e:
                # Best effort: report and skip the image.
                print(f"跳过图像 {path}: {str(e)}")
        if not features:
            raise RuntimeError("No image could be loaded; nothing to visualize")
        return torch.cat(features), loaded_paths

    def _encode_texts(self, texts) -> torch.Tensor:
        """Tokenize and embed the texts on ``self.device``."""
        # Local import: the surrounding snippet only imports ``clip`` inside
        # its ``__main__`` guard.
        import clip
        text_inputs = torch.cat([clip.tokenize(txt) for txt in texts]).to(self.device)
        with torch.no_grad():
            return self.model.encode_text(text_inputs)

    def _extract_features(self, image_paths, texts):
        """Return ``(image_features, text_features)``; unloadable images are skipped."""
        image_features, _ = self._encode_images(image_paths)
        return image_features, self._encode_texts(texts)

    def _plot_embeddings(self, embeddings, n_images, image_paths, texts, **kwargs):
        """Scatter-plot the 2-D points; the first ``n_images`` rows are images."""
        figsize = kwargs.get('figsize', (15, 10))
        point_size = kwargs.get('point_size', 50)
        sample_images = kwargs.get('sample_images', 5)
        fig, ax = plt.subplots(figsize=figsize)
        ax.scatter(
            embeddings[:n_images, 0], embeddings[:n_images, 1],
            c='blue', label='Images', s=point_size, alpha=0.5
        )
        ax.scatter(
            embeddings[n_images:, 0], embeddings[n_images:, 1],
            c='red', label='Texts', s=point_size, alpha=0.7
        )
        self._add_interactive_elements(ax, embeddings, n_images, image_paths, texts, sample_images)
        ax.set_title('CLIP Multimodal Embedding Space', pad=20)
        ax.legend()
        plt.tight_layout()
        return fig

    def _add_interactive_elements(self, ax, embeddings, n_images, image_paths, texts, sample_images):
        """Annotate text points with labels and some image points with thumbnails."""
        # Text labels, truncated to 15 characters.
        for i in range(n_images, len(embeddings)):
            label = texts[i - n_images]
            if len(label) > 15:
                label = label[:15] + "..."
            ax.annotate(
                label,
                (embeddings[i, 0], embeddings[i, 1]),
                fontsize=8, alpha=0.8
            )
        if sample_images <= 0:
            return  # no thumbnails requested; also avoids dividing by zero
        # Thumbnails on an evenly spaced subset of the image points.
        step = max(1, n_images // sample_images)
        for i in range(0, n_images, step):
            try:
                with Image.open(image_paths[i]) as img:
                    img.thumbnail((100, 100))
                    im = OffsetImage(img, zoom=0.5)
                ab = AnnotationBbox(
                    im, (embeddings[i, 0], embeddings[i, 1]),
                    frameon=False, pad=0
                )
                ax.add_artist(ab)
            except Exception as e:
                print(f"无法加载缩略图 {image_paths[i]}: {str(e)}")
# 使用示例
if __name__ == "__main__":
    import clip

    # Load CLIP on the GPU when available, else on the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    # Five images and five captions to embed together.
    image_paths = ["cat.jpeg", "dog.jpg", "bird.jpeg", "car.jpg", "building.jpg"]
    texts = [
        "a photo of a cat",
        "a picture of a dog",
        "a bird flying in the sky",
        "a red car on the road",
        "a modern office building",
    ]

    # Build the plot, save it to disk, then display it.
    visualizer = MultimodalVisualizer(model, preprocess, device=device)
    plot_options = {"sample_images": 2, "point_size": 80}
    fig = visualizer.visualize_embeddings(
        image_paths=image_paths,
        texts=texts,
        **plot_options,
    )
    plt.savefig("multimodal-embedding-space.png")
    plt.show()
    # Output:
四、自定义CLIP训练
1. 数据准备
import torch
from torch.utils.data import Dataset
from PIL import Image
import clip
from typing import List, Callable, Optional
import numpy as np
import os
class ImageTextDataset(Dataset):
    """Paired image/text dataset with optional preloading and load retries."""

    def __init__(
        self,
        image_paths: List[str],
        texts: List[str],
        transform: Optional[Callable] = None,
        preload_images: bool = False,
        max_text_length: int = 77,
        tokenizer: Optional[Callable] = None,
        retry_on_error: int = 3
    ):
        """Store the sample lists and optionally preload all images.

        Args:
            image_paths: Image file paths.
            texts: Text descriptions, ``texts[i]`` paired with ``image_paths[i]``.
            transform: Optional callable applied to each RGB image.
            preload_images: Load every image into memory up front.
            max_text_length: Token context length passed to the tokenizer.
            tokenizer: Callable ``tokenizer(text, context_length=..., truncate=...)``
                whose result is indexed with ``[0]``. Defaults to ``clip.tokenize``,
                resolved lazily so defining the class does not require ``clip``.
            retry_on_error: Number of extra attempts after a failed image load.

        Raises:
            ValueError: If the two lists differ in length.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(image_paths) != len(texts):
            raise ValueError("图像和文本数量必须相同")
        if tokenizer is None:
            import clip
            tokenizer = clip.tokenize
        self.image_paths = image_paths
        self.texts = texts
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_text_length = max_text_length
        self.retry_on_error = retry_on_error
        # ``None`` means "not preloaded"; otherwise one entry per sample,
        # with ``None`` marking an image that failed to load.
        self.preloaded = None
        if preload_images:
            self._preload_images()

    @staticmethod
    def _open_rgb(path: str):
        """Open ``path`` and return an RGB copy; the file is closed afterwards."""
        with Image.open(path) as handle:
            return handle.convert('RGB')

    def _preload_images(self):
        """Load every image into memory, recording ``None`` for failures."""
        self.preloaded = []
        for path in self.image_paths:
            for attempt in range(self.retry_on_error + 1):
                try:
                    self.preloaded.append(self._open_rgb(path))
                    break
                except Exception as e:
                    if attempt == self.retry_on_error:
                        print(f"无法加载图像 {path}: {str(e)}")
                        self.preloaded.append(None)

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> tuple:
        """Return ``(image, text_tokens)`` for sample ``idx``.

        Returns ``(None, None)`` when the image cannot be loaded or
        transformed after all retries, or when it already failed to preload.
        """
        text_tokens = self.tokenizer(
            self.texts[idx],
            context_length=self.max_text_length,
            truncate=True,
        )[0]
        for attempt in range(self.retry_on_error + 1):
            try:
                if self.preloaded is not None:
                    img = self.preloaded[idx]
                    if img is None:  # already failed during preloading
                        return None, None
                else:
                    img = self._open_rgb(self.image_paths[idx])
                if self.transform:
                    img = self.transform(img)
                return img, text_tokens
            except Exception as e:
                if attempt == self.retry_on_error:
                    print(f"加载失败 {self.image_paths[idx]}: {str(e)}")
                    if self.preloaded is not None:
                        self.preloaded[idx] = None  # remember the failure
                    return None, None

    def get_valid_samples(self) -> 'ImageTextDataset':
        """Return a new, non-preloaded dataset holding only the usable samples.

        A sample is kept when its file exists on disk and, if this dataset
        was preloaded, its image did not fail to load.
        """
        valid_indices = []
        for i, img_path in enumerate(self.image_paths):
            if self.preloaded is not None and self.preloaded[i] is None:
                continue
            if not os.path.exists(img_path):
                continue
            valid_indices.append(i)
        return ImageTextDataset(
            image_paths=[self.image_paths[i] for i in valid_indices],
            texts=[self.texts[i] for i in valid_indices],
            transform=self.transform,
            preload_images=False,  # do not preload a second time
            max_text_length=self.max_text_length,
            tokenizer=self.tokenizer,
            retry_on_error=self.retry_on_error
        )
# 使用示例
if __name__ == "__main__":
    import clip
    from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
    from torch.utils.data import DataLoader

    # 1. Load CLIP; only its preprocessing function is kept.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    _, preprocess = clip.load("ViT-B/32", device=device)

    # 2. Hand-written preprocessing pipeline.
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    custom_transform = Compose([
        Resize(256),
        CenterCrop(224),
        lambda x: x.convert("RGB"),  # force RGB mode
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ])

    # 3. Dataset with one deliberately missing file.
    dataset = ImageTextDataset(
        image_paths=["cat.jpeg", "dog.jpg", "nonexistent.jpg"],
        texts=["a cute cat", "a happy dog", "missing image"],
        transform=custom_transform,
        preload_images=True,
        retry_on_error=2,
    )

    # 4. Keep only the usable samples.
    valid_dataset = dataset.get_valid_samples()
    print(f"原始样本数: {len(dataset)} | 有效样本数: {len(valid_dataset)}")

    # 5. Batch samples, dropping the (None, None) placeholders.
    def collate_fn(batch):
        """Stack the valid samples of a batch; return None if none are valid."""
        kept = [sample for sample in batch if sample[0] is not None]
        if len(kept) == 0:
            return None
        images, texts = zip(*kept)
        return torch.stack(images), torch.stack(texts)

    dataloader = DataLoader(
        valid_dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=True,
    )

    # 6. Iterate, showing at most the first two batches.
    for batch_idx, (images, texts) in enumerate(dataloader):
        print(f"Batch {batch_idx}:")
        print(f"- 图像形状: {images.shape}")
        print(f"- 文本形状: {texts.shape}")
        if batch_idx >= 1:
            break
    # Output:
无法加载图像 nonexistent.jpg: [Errno 2] No such file or directory: '/workspace/nonexistent.jpg'
原始样本数: 3 | 有效样本数: 2
Batch 0:
- 图像形状: torch.Size([2, 3, 224, 224])
- 文本形状: torch.Size([2, 77])
2. 训练循环
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import logging
from datetime import datetime
from torch.utils.data import DataLoader
from torchvision import transforms
def setup_logger():
    """Configure root logging to a timestamped file and to the console.

    The log file is named ``clip_training_<YYYYmmdd_HHMMSS>.log``; the
    handlers are handed to ``logging.basicConfig`` at INFO level.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # File handler first, console handler second.
    file_handler = logging.FileHandler(f'clip_training_{timestamp}.log')
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
    )
def train_clip(model, train_loader, val_loader=None, epochs=5, device='cuda', save_path='best_clip_model.pth'):
    """Train a CLIP-style model with the symmetric contrastive loss.

    Args:
        model: Model whose forward pass returns ``(image_features,
            text_features)`` and which exposes a log-space ``logit_scale``.
        train_loader: DataLoader yielding ``(images, texts)`` batches. A
            ``None`` batch (e.g. from a filtering collate function) is skipped.
        val_loader: Optional validation DataLoader.
        epochs: Number of training epochs.
        device: Training device ('cuda' or 'cpu').
        save_path: Where the state dict with the best validation loss is saved.

    Returns:
        The trained model.
    """
    setup_logger()
    logger = logging.getLogger(__name__)
    model = model.to(device)

    optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
    # One scheduler step per batch; at least 1 so an empty loader does not
    # produce a zero-length cosine schedule.
    scheduler = CosineAnnealingLR(optimizer, T_max=max(1, epochs * len(train_loader)))

    best_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{epochs}', leave=False)
        for batch in progress_bar:
            if batch is None:  # every sample of this batch was filtered out
                continue
            images, texts = batch
            images, texts = images.to(device), texts.to(device)

            image_features, text_features = model(images, texts)
            image_features = F.normalize(image_features, dim=-1)
            text_features = F.normalize(text_features, dim=-1)

            # Cap the temperature multiplier at 100 so the logits cannot blow
            # up as the scale is learned.
            logit_scale = model.logit_scale.exp().clamp(max=100)
            logits_per_image = logit_scale * image_features @ text_features.t()
            logits_per_text = logits_per_image.t()

            # Matching pairs lie on the diagonal.
            labels = torch.arange(len(images), device=device)
            loss = (F.cross_entropy(logits_per_image, labels) +
                    F.cross_entropy(logits_per_text, labels)) / 2

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            num_batches += 1
            progress_bar.set_postfix({'loss': loss.item()})

        # Average over the batches actually processed; NaN if there were none.
        avg_train_loss = total_loss / num_batches if num_batches else float('nan')
        logger.info(f"Epoch {epoch + 1}/{epochs} - 训练损失: {avg_train_loss:.4f}")

        if val_loader is not None:
            val_loss = evaluate(model, val_loader, device)
            logger.info(f"Epoch {epoch + 1}/{epochs} - 验证损失: {val_loss:.4f}")
            # Keep the checkpoint with the lowest validation loss so far.
            if val_loss < best_loss:
                best_loss = val_loss
                torch.save(model.state_dict(), save_path)
                logger.info(f"保存新的最佳模型,验证损失: {val_loss:.4f}")
    return model
def evaluate(model, data_loader, device='cuda'):
    """Return the mean symmetric contrastive loss over ``data_loader``.

    Args:
        model: Model whose forward pass returns ``(image_features,
            text_features)`` and which exposes a log-space ``logit_scale``.
        data_loader: Iterable of ``(images, texts)`` batches; ``None`` batches
            are skipped.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch losses as a float, or NaN when no batch was
        evaluated.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in data_loader:
            if batch is None:
                continue
            images, texts = batch
            images, texts = images.to(device), texts.to(device)
            image_features, text_features = model(images, texts)
            image_features = F.normalize(image_features, dim=-1)
            text_features = F.normalize(text_features, dim=-1)
            # Cap the temperature multiplier at 100 so the logits stay bounded.
            logit_scale = model.logit_scale.exp().clamp(max=100)
            logits_per_image = logit_scale * image_features @ text_features.t()
            logits_per_text = logits_per_image.t()
            # Matching pairs lie on the diagonal.
            labels = torch.arange(len(images), device=device)
            loss = (F.cross_entropy(logits_per_image, labels) +
                    F.cross_entropy(logits_per_text, labels)) / 2
            total_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        # Nothing evaluated: NaN never compares as "better" than a real loss.
        return float('nan')
    return total_loss / num_batches
# 1. 定义一个简单的CLIP模型结构(示例)
class SimpleCLIP(nn.Module):
    """Minimal CLIP-style model: a small CNN image tower and an LSTM text tower."""

    def __init__(self, image_embed_dim=512, text_embed_dim=512):
        """Build both encoders and the learnable temperature.

        Args:
            image_embed_dim: Output dimension of the image encoder.
            text_embed_dim: Hidden size (and output dimension) of the LSTM.
        """
        super().__init__()
        # Image encoder: two strided convolutions, global average pooling,
        # then a linear layer.
        self.image_encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, image_embed_dim)
        )
        # Text encoder: 2-layer LSTM over 300-dimensional input vectors.
        self.text_encoder = nn.LSTM(
            input_size=300,
            hidden_size=text_embed_dim,
            num_layers=2,
            batch_first=True)
        # The parameter is exponentiated wherever it is used, so it must
        # store log(1 / 0.07) for the effective scale to start at 1 / 0.07.
        # Storing 0.07 itself would give a scale of exp(0.07), about 1.07.
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1.0 / 0.07)))

    def forward(self, images, texts):
        """Return ``(image_features, text_features)`` (not normalised).

        Args:
            images: Image batch with 3 channels.
            texts: Batch-first sequences of 300-dimensional vectors.
        """
        image_features = self.image_encoder(images)
        _, (hidden, _) = self.text_encoder(texts)
        text_features = hidden[-1]  # final hidden state of the last LSTM layer
        return image_features, text_features
# 2. 准备模拟数据集(实际使用时替换为真实数据集)
class DummyDataset(torch.utils.data.Dataset):
    """Synthetic dataset of random image tensors and random text-vector sequences.

    Intended as a placeholder; swap in a real dataset for actual training.
    """

    def __init__(self, size=1000):
        """Allocate ``size`` random samples up front."""
        self.size = size
        # Random "images": 3 channels, 224x224.
        image_shape = (size, 3, 224, 224)
        self.images = torch.randn(*image_shape)
        # Random "texts": sequences of 20 vectors with 300 dimensions each.
        text_shape = (size, 20, 300)
        self.texts = torch.randn(*text_shape)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return the ``(image, text)`` pair stored at ``idx``."""
        image = self.images[idx]
        text = self.texts[idx]
        return image, text
# 3. 数据预处理和加载
# Image preprocessing pipeline (defined here but not applied to the dummy data).
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
])

# Synthetic train/validation sets and their loaders.
train_dataset = DummyDataset(size=1000)
val_dataset = DummyDataset(size=200)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# 4. Build the model on the chosen device and train it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = SimpleCLIP().to(device)
trained_model = train_clip(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=5,
    device=device,
    save_path='best_clip_model.pth',
)
# 5. 使用训练好的模型(示例)
def encode_image(model, image):
    """Encode a single image into an L2-normalised feature vector.

    Args:
        model: Module exposing an ``image_encoder`` submodule.
        image: One image tensor without a batch dimension.

    Returns:
        Tensor with a leading batch dimension of 1, normalised along the
        last dimension.
    """
    model.eval()
    # Run on the model's own device instead of relying on a module-level
    # ``device`` variable; fall back to the input's device if the model has
    # no parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else image.device
    with torch.no_grad():
        batch = image.unsqueeze(0).to(target_device)  # add the batch dimension
        features = model.image_encoder(batch)
        return F.normalize(features, dim=-1)
def encode_text(model, text):
    """Encode a single text sequence into an L2-normalised feature vector.

    Args:
        model: Module exposing a ``text_encoder`` that returns
            ``(output, (hidden, cell))``, as ``nn.LSTM`` does.
        text: One sequence without a batch dimension.

    Returns:
        Tensor with a leading batch dimension of 1, normalised along the
        last dimension.
    """
    model.eval()
    # Run on the model's own device instead of relying on a module-level
    # ``device`` variable; fall back to the input's device if the model has
    # no parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else text.device
    with torch.no_grad():
        batch = text.unsqueeze(0).to(target_device)  # add the batch dimension
        _, (hidden, _) = model.text_encoder(batch)
        features = hidden[-1]  # final hidden state of the last layer
        return F.normalize(features, dim=-1)
# 示例使用
# Random stand-ins for one test image and one test text sequence.
test_image = torch.randn(3, 224, 224)
test_text = torch.randn(20, 300)

image_feature = encode_image(trained_model, test_image)
text_feature = encode_text(trained_model, test_text)

# Scaled similarity between the two normalised feature vectors.
raw_similarity = image_feature @ text_feature.T
similarity = raw_similarity * trained_model.logit_scale.exp()
print(f"图像-文本相似度: {similarity.item():.4f}")
# Output:
2025-04-02 02:24:47,144 - INFO - Epoch 1/5 - 训练损失: 3.4226
2025-04-02 02:24:47,216 - INFO - Epoch 1/5 - 验证损失: 3.2677
2025-04-02 02:24:47,238 - INFO - 保存新的最佳模型,验证损失: 3.2677
2025-04-02 02:24:47,935 - INFO - Epoch 2/5 - 训练损失: 3.4223
2025-04-02 02:24:48,016 - INFO - Epoch 2/5 - 验证损失: 3.2677
2025-04-02 02:24:48,065 - INFO - 保存新的最佳模型,验证损失: 3.2677
2025-04-02 02:24:48,772 - INFO - Epoch 3/5 - 训练损失: 3.4221
2025-04-02 02:24:48,845 - INFO - Epoch 3/5 - 验证损失: 3.2677
2025-04-02 02:24:48,899 - INFO - 保存新的最佳模型,验证损失: 3.2677
2025-04-02 02:24:49,583 - INFO - Epoch 4/5 - 训练损失: 3.4220
2025-04-02 02:24:49,653 - INFO - Epoch 4/5 - 验证损失: 3.2677
2025-04-02 02:24:49,706 - INFO - 保存新的最佳模型,验证损失: 3.2677
2025-04-02 02:24:50,380 - INFO - Epoch 5/5 - 训练损失: 3.4219
2025-04-02 02:24:50,450 - INFO - Epoch 5/5 - 验证损失: 3.2677
2025-04-02 02:24:50,496 - INFO - 保存新的最佳模型,验证损失: 3.2677
图像-文本相似度: -0.0156
五、高级应用拓展
1. 跨模态检索增强
import torch
import clip
from PIL import Image
import os
import matplotlib.pyplot as plt
import numpy as np
def retrieve_images(query_text, image_db, model, preprocess, device, top_k=5, display=True):
    """Rank images by CLIP similarity to a text query.

    Args:
        query_text: The query string.
        image_db: List of image paths to search.
        model: CLIP model providing ``encode_text`` / ``encode_image``.
        preprocess: Image preprocessing function.
        device: Compute device.
        top_k: Maximum number of results to return.
        display: Whether to plot the retrieved images.

    Returns:
        List of ``(image_path, similarity_score)`` tuples in descending order
        of similarity. Images that fail to load or encode are skipped.
    """
    text_input = clip.tokenize([query_text]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_input)

    similarities = []
    for img_path in image_db:
        try:
            # Close the file as soon as preprocessing has read it.
            with Image.open(img_path) as raw_image:
                image = preprocess(raw_image).unsqueeze(0).to(device)
            with torch.no_grad():
                image_features = model.encode_image(image)
            sim = torch.cosine_similarity(text_features, image_features)
            similarities.append((img_path, sim.item()))
        except Exception as e:
            # Best effort: report the failure and move on to the next image.
            print(f"Error processing {img_path}: {str(e)}")

    sorted_results = sorted(similarities, key=lambda item: item[1], reverse=True)[:top_k]

    # Skip plotting when nothing was retrieved, and size the grid by the
    # number of results actually available rather than by ``top_k``.
    if display and sorted_results:
        plt.figure(figsize=(15, 5))
        plt.suptitle(f'Query: "{query_text}"', fontsize=16)
        for i, (img_path, sim_score) in enumerate(sorted_results):
            plt.subplot(1, len(sorted_results), i + 1)
            with Image.open(img_path) as img:
                plt.imshow(img)
            plt.title(f"Score: {sim_score:.3f}")
            plt.axis('off')
        plt.tight_layout()
        plt.show()
    return sorted_results
# 示例使用
if __name__ == "__main__":
    # Load CLIP on the GPU when available, else on the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    # Collect every PNG/JPEG in the folder (replace with your own path).
    image_folder = "sample_images"
    image_extensions = ('.png', '.jpg', '.jpeg')
    image_db = []
    for f in os.listdir(image_folder):
        if f.lower().endswith(image_extensions):
            image_db.append(os.path.join(image_folder, f))

    # Run the query and list the best matches.
    query = "a happy dog playing in the park"
    results = retrieve_images(query, image_db, model, preprocess, device, top_k=3)

    print("\nTop results:")
    for i, (img_path, score) in enumerate(results):
        print(f"{i+1}. {img_path} - Similarity: {score:.4f}")
    # Output:
Top results:
1. sample_images/dog.jpg - Similarity: 0.2151
2. sample_images/bird.jpeg - Similarity: 0.1532
2. 提示工程优化
import torch
import clip
from PIL import Image
import os
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
# 设备设置
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


def load_clip_model(model_name="ViT-B/32"):
    """Load the named CLIP model onto the module-level ``device``.

    Returns:
        The ``(model, preprocess)`` pair produced by ``clip.load``.
    """
    loaded = clip.load(model_name, device=device)
    model, preprocess = loaded
    print(f"Loaded CLIP {model_name} on {device}")
    return model, preprocess
def optimize_prompt(class_name, templates, model, visualize=False):
    """Build one text feature for a class by averaging several prompt templates.

    Each template is filled with ``class_name``, encoded, and L2-normalised;
    the normalised features are averaged and the mean is normalised again.

    Args:
        class_name: Target class name (e.g. "cat").
        templates: Format strings with one ``{}`` placeholder each.
        model: CLIP model providing ``encode_text``.
        visualize: Whether to plot the per-template features.

    Returns:
        The averaged, normalised text feature with the leading batch
        dimension squeezed away.
    """
    prompts = [t.format(class_name) for t in templates]
    token_batches = [clip.tokenize(prompt) for prompt in prompts]
    text_inputs = torch.cat(token_batches).to(device)

    with torch.no_grad():
        encoded = model.encode_text(text_inputs)
        text_features = encoded / encoded.norm(dim=-1, keepdim=True)

    # Mean over the templates, renormalised to unit length.
    averaged = text_features.mean(dim=0, keepdim=True)
    mean_features = averaged / averaged.norm(dim=-1, keepdim=True)

    if visualize:
        visualize_features(text_features.cpu().numpy(), templates, class_name)
    return mean_features.squeeze(0)
def visualize_features(features, templates, class_name):
    """Plot a 2-D PCA projection of the per-template text features.

    Each point is labelled with its template (placeholder removed); the mean
    of the projected points is drawn as a red star labelled "Optimized".
    """
    reduced = PCA(n_components=2).fit_transform(features)

    plt.figure(figsize=(10, 8))
    plt.scatter(reduced[:, 0], reduced[:, 1], c='blue', s=100)

    # Label every point with its template text.
    for template, (x, y) in zip(templates, reduced):
        short_template = template.replace("{}", "").strip()
        if not short_template:
            short_template = "plain"
        plt.annotate(short_template, (x, y), textcoords="offset points", xytext=(0,10), ha='center')

    # Mark the centroid of the projected points.
    mean_point = reduced.mean(axis=0)
    plt.scatter(mean_point[0], mean_point[1], c='red', s=200, marker='*')
    plt.annotate("Optimized", mean_point, textcoords="offset points", xytext=(0,15), ha='center', color='red')

    plt.title(f'Prompt Feature Space for "{class_name}"\n(PCA Projection)')
    plt.xlabel("Principal Component 1")
    plt.ylabel("Principal Component 2")
    plt.grid(True)
    plt.show()
def calculate_similarity(image_feature, text_feature):
    """Return 100 times the dot product of the two features as a Python float.

    For unit-length inputs this is the cosine similarity scaled by 100.

    Args:
        image_feature: Image feature of shape [1, embed_dim].
        text_feature: Text feature of shape [embed_dim] or [1, embed_dim].
    """
    # Promote a 1-D text feature to a single-row matrix.
    text_matrix = text_feature if text_feature.dim() != 1 else text_feature.unsqueeze(0)
    dot_product = image_feature @ text_matrix.mT
    scaled = 100.0 * dot_product
    return scaled.item()
def evaluate_prompt(model, preprocess, class_name, prompt_type="optimized"):
    """Score a test image against a single prompt or a template ensemble.

    The test image is ``"<class_name>.jpg"``; if it cannot be loaded, a
    random tensor is used instead.

    Args:
        model: CLIP model.
        preprocess: Image preprocessing function.
        class_name: Class to evaluate (also the test image's base name).
        prompt_type: "optimized" for the multi-template ensemble; any other
            value uses the single prompt "a photo of a <class_name>".

    Returns:
        The similarity score from ``calculate_similarity``.
    """
    image_path = f"{class_name}.jpg"
    try:
        # Close the file once preprocessing has read it.
        with Image.open(image_path) as raw_image:
            image = preprocess(raw_image).unsqueeze(0).to(device)
    except Exception:
        # Best-effort fallback. ``Exception`` rather than a bare ``except``
        # so KeyboardInterrupt / SystemExit still propagate.
        print(f"Test image {image_path} not found, using random image")
        image = torch.randn(1, 3, 224, 224).to(device)

    with torch.no_grad():
        image_feature = model.encode_image(image)
        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)

    if prompt_type == "optimized":
        templates = [
            "a photo of a {}",
            "a bad photo of a {}",
            "a cropped photo of the {}",
            "a good photo of the {}",
            "a low resolution photo of a {}",
            "a high resolution photo of a {}",
            "a close-up photo of a {}",
            "a black and white photo of the {}"
        ]
        text_feature = optimize_prompt(class_name, templates, model, visualize=True)
    else:
        text_input = clip.tokenize([f"a photo of a {class_name}"]).to(device)
        # Inference only: avoid building an autograd graph, as in the
        # optimized branch.
        with torch.no_grad():
            text_feature = model.encode_text(text_input)
            text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
        text_feature = text_feature.squeeze(0)

    similarity = calculate_similarity(image_feature, text_feature)
    print(f"{prompt_type.capitalize()} prompt similarity: {similarity:.2f}")
    return similarity
if __name__ == "__main__":
    # 1. Load the model and its preprocessing function.
    model, preprocess = load_clip_model()

    # 2. Class under test (replace with any class you like).
    class_name = "dog"

    # 3. Score the single prompt first, then the template ensemble.
    print("\n=== Prompt Engineering Evaluation ===")
    scores = {}
    for mode in ("single", "optimized"):
        scores[mode] = evaluate_prompt(model, preprocess, class_name, mode)
    single_score = scores["single"]
    optimized_score = scores["optimized"]

    # 4. Report the absolute and relative gain.
    improvement = optimized_score - single_score
    print(f"\nImprovement from prompt engineering: {improvement:.2f} points")
    print(f"Relative improvement: {improvement/single_score*100:.1f}%")
    # Output:
Loaded CLIP ViT-B/32 on cuda
=== Prompt Engineering Evaluation ===
Single prompt similarity: 24.89
Optimized prompt similarity: 25.83
Improvement from prompt engineering: 0.94 points
Relative improvement: 3.8%
六、总结与展望
本文深入探讨了:
- CLIP架构原理:对比学习目标与双塔设计
- 零样本能力:无需微调的新类别识别
- 跨模态应用:图文检索与特征空间对齐
- 自定义训练:实现领域自适应
在下一篇文章《联邦学习与隐私保护》中,我们将探索如何在分布式环境下实现安全的多模态学习。
关键工具推荐:
pip install clip-anytorch umap-learn
应用建议:
- 产品推荐系统使用CLIP实现跨模态搜索
- 内容审核结合提示工程增强分类效果
- 机器人导航通过图文对齐理解环境
- 上一篇:黑帽SEO博文自动跳转到自己网站
- 下一篇:哥们网《风暴大陆》试炼大冒险曝光
相关推荐
- Python入门学习记录之一:变量_python怎么用变量
-
写这个,主要是对自己学习python知识的一个总结,也是加深自己的印象。变量(英文:variable),也叫标识符。在python中,变量的命名规则有以下三点:>变量名只能包含字母、数字和下划线...
- python变量命名规则——来自小白的总结
-
python是一个动态编译类编程语言,所以程序在运行前不需要如C语言的先行编译动作,因此也只有在程序运行过程中才能发现程序的问题。基于此,python的变量就有一定的命名规范。python作为当前热门...
- Python入门学习教程:第 2 章 变量与数据类型
-
2.1什么是变量?在编程中,变量就像一个存放数据的容器,它可以存储各种信息,并且这些信息可以被读取和修改。想象一下,变量就如同我们生活中的盒子,你可以把东西放进去,也可以随时拿出来看看,甚至可以换成...
- 绘制学术论文中的“三线表”具体指导
-
在科研过程中,大家用到最多的可能就是“三线表”。“三线表”,一般主要由三条横线构成,当然在变量名栏里也可以拆分单元格,出现更多的线。更重要的是,“三线表”也是一种数据记录规范,以“三线表”形式记录的数...
- Python基础语法知识--变量和数据类型
-
学习Python中的变量和数据类型至关重要,因为它们构成了Python编程的基石。以下是帮助您了解Python中的变量和数据类型的分步指南:1.变量:变量在Python中用于存储数据值。它们充...
- 一文搞懂 Python 中的所有标点符号
-
反引号`无任何作用。传说Python3中它被移除是因为和单引号字符'太相似。波浪号~(按位取反符号)~被称为取反或补码运算符。它放在我们想要取反的对象前面。如果放在一个整数n...
- Python变量类型和运算符_python中变量的含义
-
别再被小名词坑哭了:Python新手常犯的那些隐蔽错误,我用同事的真实bug拆给你看我记得有一次和同事张姐一起追查一个看似随机崩溃的脚本,最后发现罪魁祸首竟然是她把变量命名成了list。说实话...
- 从零开始:深入剖析 Spring Boot3 中配置文件的加载顺序
-
在当今的互联网软件开发领域,SpringBoot无疑是最为热门和广泛应用的框架之一。它以其强大的功能、便捷的开发体验,极大地提升了开发效率,成为众多开发者构建Web应用程序的首选。而在Spr...
- Python中下划线 ‘_’ 的用法,你知道几种
-
Python中下划线()是一个有特殊含义和用途的符号,它可以用来表示以下几种情况:1在解释器中,下划线(_)表示上一个表达式的值,可以用来进行快速计算或测试。例如:>>>2+...
- 解锁Shell编程:变量_shell $变量
-
引言:开启Shell编程大门Shell作为用户与Linux内核之间的桥梁,为我们提供了强大的命令行交互方式。它不仅能执行简单的文件操作、进程管理,还能通过编写脚本实现复杂的自动化任务。无论是...
- 一文学会Python的变量命名规则!_python的变量命名有哪些要求
-
目录1.变量的命名原则3.内置函数尽量不要做变量4.删除变量和垃圾回收机制5.结语1.变量的命名原则①由英文字母、_(下划线)、或中文开头②变量名称只能由英文字母、数字、下画线或中文字所组成。③英文字...
- 更可靠的Rust-语法篇-区分语句/表达式,略览if/loop/while/for
-
src/main.rs://函数定义fnadd(a:i32,b:i32)->i32{a+b//末尾表达式}fnmain(){leta:i3...
- C++第五课:变量的命名规则_c++中变量的命名规则
-
变量的命名不是想怎么起就怎么起的,而是有一套固定的规则的。具体规则:1.名字要合法:变量名必须是由字母、数字或下划线组成。例如:a,a1,a_1。2.开头不能是数字。例如:可以a1,但不能起1a。3....
- Rust编程-核心篇-不安全编程_rust安全性
-
Unsafe的必要性Rust的所有权系统和类型系统为我们提供了强大的安全保障,但在某些情况下,我们需要突破这些限制来:与C代码交互实现底层系统编程优化性能关键代码实现某些编译器无法验证的安全操作Rus...
- 探秘 Python 内存管理:背后的神奇机制
-
在编程的世界里,内存管理就如同幕后的精密操控者,确保程序的高效运行。Python作为一种广泛使用的编程语言,其内存管理机制既巧妙又复杂,为开发者们提供了便利的同时,也展现了强大的底层控制能力。一、P...
- 一周热门
- 最近发表
- 标签列表
-
- HTML 教程 (33)
- HTML 简介 (35)
- HTML 实例/测验 (32)
- HTML 测验 (32)
- JavaScript 和 HTML DOM 参考手册 (32)
- HTML 拓展阅读 (30)
- HTML文本框样式 (31)
- HTML滚动条样式 (34)
- HTML5 浏览器支持 (33)
- HTML5 新元素 (33)
- HTML5 WebSocket (30)
- HTML5 代码规范 (32)
- HTML5 标签 (717)
- HTML5 标签 (已废弃) (75)
- HTML5电子书 (32)
- HTML5开发工具 (34)
- HTML5小游戏源码 (34)
- HTML5模板下载 (30)
- HTTP 状态消息 (33)
- HTTP 方法:GET 对比 POST (33)
- 键盘快捷键 (35)
- 标签 (226)
- opacity 属性 (32)
- transition 属性 (33)
- 1-1. 变量声明 (31)
