from transformers import BlipForConditionalGeneration, BlipProcessor, AutoTokenizer
from torch.optim import AdamW
from PIL import Image
from datasets import load_dataset

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
bertTokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

# data.csv is expected to contain "image_path" and "caption" columns
input_dataset = load_dataset(path="data", data_files="data.csv")

# Load every image referenced in the CSV into memory (no batching; fine for a small toy dataset)
image_list = [Image.open(img_path) for img_path in input_dataset["train"]["image_path"]]

image_inputs = processor(images=image_list, return_tensors="pt")
text_inputs = bertTokenizer(input_dataset["train"]["caption"],
                            max_length=128,
                            padding="max_length",
                            truncation=True,
                            add_special_tokens=False,
                            return_tensors="pt",
                            return_token_type_ids=False)

pixel_values = image_inputs["pixel_values"]
text_ids = text_inputs["input_ids"]
attention_mask = text_inputs["attention_mask"]

# To train from scratch, initialize the model from a config instead of from_pretrained (see the sketch below)
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
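# A from-scratch sketch (my assumption, not part of the original pipeline): build the model
# from a config to get randomly initialized weights. The default BlipConfig values are
# illustrative, not tuned settings. Uncomment to use:
# from transformers import BlipConfig
# scratch_config = BlipConfig()                          # default BLIP architecture hyperparameters
# model = BlipForConditionalGeneration(scratch_config)   # random init, no pretrained weights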

learning_rate = 5e-5
epochs = 30
optimizer = AdamW(model.parameters(), lr=learning_rate)

model.train()
for epoch in range(epochs):
    # Full-batch training: the whole dataset is fed to the model in a single step each epoch
    outputs = model(pixel_values=pixel_values,
                    input_ids=text_ids,
                    attention_mask=attention_mask,
                    labels=text_ids)

    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(f"Epoch {epoch + 1}/{epochs} completed. Loss: {loss.item()}")

model.eval()

# Caption the same images with the fine-tuned model and decode with the Chinese tokenizer
output_batch = model.generate(pixel_values=pixel_values, max_length=128)
for sequence in output_batch:
    caption = bertTokenizer.decode(sequence, skip_special_tokens=True)
    print(caption)

Note: the text_ids fed to the model should start with a [PAD]/0 token, because the later label handling shifts input_ids one position to the right (the first token only serves as the decoder start position and is never predicted).
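One way to satisfy this (a sketch based on the note above, not code from the original pipeline; the torch.cat approach and keeping the length at 128 are my assumptions) is to prepend the pad token id right after the bertTokenizer call, before training:

import torch

pad_id = bertTokenizer.pad_token_id  # 0 for bert-base-chinese
pad_col = torch.full((text_ids.shape[0], 1), pad_id, dtype=text_ids.dtype)
# Drop the last column so the sequence length stays at max_length after prepending
text_ids = torch.cat([pad_col, text_ids[:, :-1]], dim=1)
# Mark the prepended position as attended, since it acts as the decoder start token
attention_mask = torch.cat([torch.ones_like(pad_col), attention_mask[:, :-1]], dim=1)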