Image classification: fine-tuning the google/vit-large-patch32-384 model

Background:

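Image classification is a very common task, for example everyday object recognition, but we often need to train a classifier on our own data. Now that large pretrained models are available, we no longer have to train a model from scratch; we can fine-tune a model that has already been trained. This article fine-tunes the google/vit-large-patch32-384 model.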

Data preparation:

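For data preparation, see the CSDN blog post "huggingface如何加载本地数据集进行大模型训练" (How to load a local dataset in Hugging Face for large-model training).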
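The script below reads everything from one folder: image.json maps each image file name (relative to an image/ subfolder) to a label id, and label.json maps each label name to its id. Assuming that layout, and assuming the ids are the integers 0..N-1 (the JSON contents themselves are not shown in this article), a small stdlib-only check like the following can catch missing files or unknown labels before a long training run. The function name check_dataset is hypothetical.

```python
import json
import os


def check_dataset(path):
    """Return a list of problems in path/label.json, path/image.json and path/image/.

    Assumed layout (inferred from gen() and get_label() in the script):
      label.json: {"label name": label_id, ...}
      image.json: {"file name": label_id, ...}
      image/:     the image files listed in image.json
    """
    with open(os.path.join(path, "label.json")) as f:
        labels = json.load(f)
    with open(os.path.join(path, "image.json")) as f:
        images = json.load(f)

    problems = []
    ids = sorted(labels.values())
    # Classification heads expect class ids 0..N-1
    if ids != list(range(len(ids))):
        problems.append(f"label ids are not 0..{len(ids) - 1}: {ids}")
    known_ids = set(ids)
    for name, label in images.items():
        if not os.path.isfile(os.path.join(path, "image", name)):
            problems.append(f"missing file: {name}")
        if label not in known_ids:
            problems.append(f"unknown label {label!r} for {name}")
    return problems


# Hypothetical usage with the dataset path from the script:
# for problem in check_dataset("/data/dataset/image"):
#     print(problem)
```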

Code:

import json
import os
from PIL import Image
from datasets import Dataset
from sklearn.metrics import accuracy_score, f1_score, recall_score
from transformers import AutoImageProcessor
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
from transformers import DefaultDataCollator

path = '/data/dataset/image'


def gen(path):
    image_json = os.path.join(path, "image.json")
    with open(image_json, 'r') as f:
        # Read the JSON data (image file name -> label id)
        data = json.load(f)
    for key, value in data.items():
        imagePath = os.path.join(path, "image")
        imagePath = os.path.join(imagePath, key)
        image = Image.open(imagePath)
        yield {'image': image, 'label': value}


def get_label(path):
    label_json = os.path.join(path, "label.json")
    with open(label_json, 'r') as f:
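        # Read the JSON data (label name -> label id)
        data = json.load(f)
    label2id, id2label = dict(), dict()
    for key, value in data.items():
        # Keep ids as ints, matching the integer class ids used as dataset labels
        label2id[key] = int(value)
        id2label[int(value)] = key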
    return label2id, id2label


ds = Dataset.from_generator(gen, gen_kwargs={"path": path})
ds = ds.train_test_split(test_size=0.2)
label2id, id2label = get_label(path)

checkpoint = "/data/model/vit-large-patch32-384"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])


def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


ds = ds.with_transform(transforms)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    recall = recall_score(labels, preds, average="weighted")
    return {"accuracy": acc, "f1": f1, "recall": recall}
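With one label per image, recall averaged with average="weighted" always equals accuracy: each class's recall TP_c/n_c is weighted by its share n_c/N, so the sum collapses to (sum of all TP_c)/N. That is why eval_recall matches eval_accuracy in every epoch of the training log. The pure-Python sketch below (hypothetical helper names, not sklearn's implementation) shows the identity; if a separate recall figure is wanted, average="macro" is one option.

```python
from collections import Counter


def accuracy(labels, preds):
    """Fraction of predictions that match the true label."""
    return sum(l == p for l, p in zip(labels, preds)) / len(labels)


def weighted_recall(labels, preds):
    """Per-class recall, averaged with each class weighted by its support."""
    support = Counter(labels)
    total = len(labels)
    score = 0.0
    for cls, n_cls in support.items():
        true_pos = sum(1 for l, p in zip(labels, preds) if l == cls and p == cls)
        score += (n_cls / total) * (true_pos / n_cls)
    return score


labels = [0, 0, 1, 2, 2, 2]
preds = [0, 1, 1, 2, 0, 2]
# Both values equal 4/6, up to floating-point rounding
print(accuracy(labels, preds), weighted_recall(labels, preds))
```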


model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(id2label),  # derive the class count from label.json instead of hardcoding 5
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True)

training_args = TrainingArguments(
    output_dir="my_awesome_food_model",
    remove_unused_columns=False,
    evaluation_strategy="epoch",  # newer transformers releases name this eval_strategy
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    greater_is_better=True,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy"
)
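With per_device_train_batch_size=16 and gradient_accumulation_steps=4, each optimizer update accumulates gradients over 16 × 4 = 64 images per device. On a very small training set that means a single update per epoch, which is consistent with "global_step": 3 after 3 epochs in the training results. The functions below are a simplified approximation of how the Trainer counts updates (exact rounding may differ between transformers versions); the function names and the 8-image figure in the usage lines are assumptions.

```python
import math


def effective_batch_size(per_device_batch_size, grad_accum_steps, num_devices=1):
    """Number of images contributing to one optimizer update."""
    return per_device_batch_size * grad_accum_steps * num_devices


def optimizer_steps(num_train_samples, per_device_batch_size, grad_accum_steps,
                    epochs, num_devices=1):
    """Approximate total optimizer updates for a training run."""
    # Batches each device sees per epoch; the last, partial batch is kept
    batches_per_epoch = math.ceil(num_train_samples / (per_device_batch_size * num_devices))
    # At least one update per epoch, even with fewer batches than accumulation steps
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return updates_per_epoch * epochs


print(effective_batch_size(16, 4))   # 64
print(optimizer_steps(8, 16, 4, 3))  # 3, assuming 8 training images
```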

data_collator = DefaultDataCollator()

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    tokenizer=image_processor,  # newer transformers releases take this as processing_class
    compute_metrics=compute_metrics,
)

trainer.train()

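Training results (note that the dataset here is tiny: each evaluation took about 0.18 s at about 11 samples/s, i.e. only about 2 test images, and each epoch ran a single optimizer step, so an accuracy of 1.0 says little about real-world performance):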

{
  "best_metric": 1.0,
  "best_model_checkpoint": "my_awesome_food_model/checkpoint-2",
  "epoch": 3.0,
  "eval_steps": 500,
  "global_step": 3,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [
    {
      "epoch": 1.0,
      "eval_accuracy": 0.0,
      "eval_f1": 0.0,
      "eval_loss": 1.8605551719665527,
      "eval_recall": 0.0,
      "eval_runtime": 0.1864,
      "eval_samples_per_second": 10.727,
      "eval_steps_per_second": 5.363,
      "step": 1
    },
    {
      "epoch": 2.0,
      "eval_accuracy": 1.0,
      "eval_f1": 1.0,
      "eval_loss": 1.2016913890838623,
      "eval_recall": 1.0,
      "eval_runtime": 0.175,
      "eval_samples_per_second": 11.43,
      "eval_steps_per_second": 5.715,
      "step": 2
    },
    {
      "epoch": 3.0,
      "eval_accuracy": 1.0,
      "eval_f1": 1.0,
      "eval_loss": 0.8268076181411743,
      "eval_recall": 1.0,
      "eval_runtime": 0.1774,
      "eval_samples_per_second": 11.271,
      "eval_steps_per_second": 5.635,
      "step": 3
    }
  ],
  "logging_steps": 10,
  "max_steps": 3,
  "num_train_epochs": 3,
  "save_steps": 500,
  "total_flos": 1.946783884640256e+16,
  "trial_name": null,
  "trial_params": null
}
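The JSON above has the structure of the trainer_state.json file that the Trainer writes into each checkpoint folder. A small sketch for pulling the per-epoch evaluation rows out of such a file; the function names are hypothetical, and the checkpoint path in the usage comment is an assumption based on the output_dir used in the script.

```python
import json


def load_trainer_state(file_path):
    """Read a trainer_state.json file into a dict."""
    with open(file_path) as f:
        return json.load(f)


def eval_history(state):
    """(epoch, eval_loss, eval_accuracy) for every evaluation entry in log_history."""
    return [
        (entry["epoch"], entry["eval_loss"], entry.get("eval_accuracy"))
        for entry in state["log_history"]
        if "eval_loss" in entry  # training-loss entries carry "loss" instead
    ]


# Hypothetical usage:
# state = load_trainer_state("my_awesome_food_model/checkpoint-3/trainer_state.json")
# print(state["best_model_checkpoint"])
# for epoch, loss, acc in eval_history(state):
#     print(f"epoch {epoch:.0f}: eval_loss={loss:.3f}, eval_accuracy={acc}")
```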

Notes:

The first time I ran the training, the console reported the following exception:

(.env) (base) [ipa_sudo@comm-agi image]$ python vit.py 
Traceback (most recent call last):
  File "/data/image/vit.py", line 72, in <module>
    model = AutoModelForImageClassification.from_pretrained(checkpoint, num_labels=5, id2label=id2label, label2id=label2id)
  File "/data/.env/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 516, in from_pretrained
    return model_class.from_pretrained(
  File "/data/.env/lib/python3.9/site-packages/transformers/modeling_utils.py", line 3091, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/data/.env/lib/python3.9/site-packages/transformers/modeling_utils.py", line 3532, in _load_pretrained_model
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
RuntimeError: Error(s) in loading state_dict for ViTForImageClassification:
        size mismatch for classifier.weight: copying a param with shape torch.Size([1000, 1024]) from checkpoint, the shape in current model is torch.Size([5, 1024]).
        size mismatch for classifier.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([5]).
        You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

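The message shows that the original model's classifier was built for 1000 classes (the checkpoint was fine-tuned on ImageNet-1k), while I asked for 5 classes, so the shapes of classifier.weight and classifier.bias do not match.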

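How do we fix it? The message already suggests the solution: add the argument ignore_mismatched_sizes=True to AutoModelForImageClassification.from_pretrained(). The pretrained 1000-class classifier weights are then not loaded; a new, randomly initialized 5-class head is created instead and learned during fine-tuning.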

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=5,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True)
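Conceptually, the check behind this error compares the shape of every weight in the checkpoint with the shape the freshly built model expects. The stdlib sketch below is a simplified illustration of that idea, not transformers' actual implementation; the classifier shapes come from the traceback, and the extra "vit.layernorm.weight" entry is an assumed key added for contrast.

```python
def mismatched_keys(checkpoint_shapes, model_shapes):
    """Names present in both state dicts whose tensor shapes differ.

    Without ignore_mismatched_sizes=True such keys make loading fail; with it,
    they are skipped and keep the model's freshly initialized values.
    """
    shared = checkpoint_shapes.keys() & model_shapes.keys()
    return sorted(k for k in shared if checkpoint_shapes[k] != model_shapes[k])


# A 1000-class head vs. a 5-class head on 1024-dim features
checkpoint = {"classifier.weight": (1000, 1024), "classifier.bias": (1000,),
              "vit.layernorm.weight": (1024,)}
model = {"classifier.weight": (5, 1024), "classifier.bias": (5,),
         "vit.layernorm.weight": (1024,)}
print(mismatched_keys(checkpoint, model))  # ['classifier.bias', 'classifier.weight']
```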

Summary:

1. The most tedious part is preparing the image data: the images have to be collected and also labeled.

2. The images need to be transformed (for example with random resized crops), which makes the model more robust.

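3. The training arguments must include remove_unused_columns=False; otherwise the Trainer drops the image column (it is not an input of the model's forward method), and pixel_values can no longer be computed from it.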
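Roughly, with remove_unused_columns=True the Trainer keeps only the dataset columns whose names match parameters of the model's forward() (plus label columns such as "label") and drops the rest before the on-the-fly transform runs. The sketch below mimics that rule with inspect.signature; forward is a hypothetical stand-in whose parameter names follow ViTForImageClassification.forward, and kept_columns is a simplification, not transformers' code.

```python
import inspect


def forward(pixel_values=None, head_mask=None, labels=None, output_attentions=None,
            output_hidden_states=None, interpolate_pos_encoding=None, return_dict=None):
    """Hypothetical stand-in: only the parameter names matter for this sketch."""


def kept_columns(column_names, forward_fn):
    """Simplified version of the Trainer's unused-column filter."""
    allowed = set(inspect.signature(forward_fn).parameters) | {"label", "label_ids"}
    return [name for name in column_names if name in allowed]


# The raw dataset has only "image" and "label"; "pixel_values" is created later by transforms()
print(kept_columns(["image", "label"], forward))  # ['label']
```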