0% read

Gemma 4 微调报错 mm_token_type_ids 解决方法

2026/06/02

一次再普通不过的 Gemma 4 纯文本微调,trainer 刚跑到第一步就挂了,丢出这么一句:

ValueError: `mm_token_type_ids` is required as a model input when training

没图像、没音频,喂进去的就是一堆聊天数据。那 Gemma 4 凭什么张口要一个多模态的 token 字段?这是 Gemma 4 首发版本里一个已知的毛刺(transformers issue #45200),想明白之后改起来其实很小。这篇给出一行临时绕过方案,外加一份完整、可直接复制的自定义 data collator,把 mm_token_type_ids 报错一次性根治。

报错到底在说什么

traceback 一路下探,最后落在模型的 forward 里:

File ".../transformers/models/gemma4/modeling_gemma4.py", line ..., in forward
    raise ValueError(
ValueError: `mm_token_type_ids` is required as a model input when training

它在训练的第一步就触发,跟你用什么数据集、LoRA 配置或者哪块 GPU 都没关系。用原生 Trainer 的会撞上,用 TRL 的 SFTTrainer 的会撞上,跑 QLoRA 的同样会撞上。共同点只有一个:拿一份纯文本数据集,去喂一个生来就以多模态为先的模型。

为什么会这样(根本原因)

Gemma 4 在设计上就是多模态的。为了把文本 token 跟图像、音频 token 区分开,它的架构在 input_ids 之外多引入了两个输入:

  • token_type_ids —— 标准的 segment 字段。
  • mm_token_type_ids —— 多模态的 token 类型字段,用来标记每个位置是文本、图像还是音频。

问题出在:Gemma 4 的模型代码会在计算 loss 的训练 forward(训练模式且带 labels)里强制校验 mm_token_type_ids 是否存在——哪怕你这一个 batch 里连一张图都没有。可偏偏标准 tokenizer 和默认的 data collator 从来不会产出这个字段。对纯文本任务来说,这个字段每个值本应都是 0,但因为它是整个缺席而不是被填成 0,模型选择直接抛错,而不是默认补一个合理值。(在首发的 transformers==5.5.0.dev0 源码构建里,相关的 raise 就藏在 modeling_gemma4.pyforward 的训练分支里;具体行号会随 commit 变动。)

所以这个 bug 本质是三方错位:模型要求mm_token_type_ids,你的 tokenizer 从不创建它,而 trainer 就算你建了也会把它扔掉。三处一起修好,报错才会消失。

临时绕过(一行)

如果训练循环在你自己手里,只想现在就把训练跑起来,那就在调用 forward 之前注入一个全零张量(记得 import torch):

import torch

inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])

纯文本数据里每个 token 都是文本 token,所以一个跟 input_ids 同形的全零张量恰好就是正确答案。对手写训练循环来说,这一行就够了。但用 Trainer / SFTTrainer 时 forward 不归你管,你得让这个字段直接长在 batch 里——下面的 collator 干的就是这件事。

完整方案:一个自定义的 Gemma 4 Data Collator

一共三块,缺一不可。漏掉任何一块,mm_token_type_ids 报错都会原样回来。

0. 最小可运行环境

为了让下面的片段能端到端跑通,先把它们假设的脚手架摆出来——换成你自己的模型和数据集即可:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "google/gemma-4-4b-it"  # any Gemma 4 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# This guide assumes each example has a `messages` field (chat format):
#   {"messages": [{"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}]}
# If your data is instruction/output style, convert it to messages first.
dataset = load_dataset("your/dataset", split="train")

1. 在分词时把两个字段都加上

分词的时候,给 token_type_idsmm_token_type_ids 各挂一个与 input_ids 等长的全零列表:

def format_chat(example):
    text = tokenizer.apply_chat_template(
        example["messages"], tokenize=False, add_generation_prompt=False
    )
    tokenized = tokenizer(text, truncation=True, max_length=4096)
    tokenized["token_type_ids"] = [0] * len(tokenized["input_ids"])
    tokenized["mm_token_type_ids"] = [0] * len(tokenized["input_ids"])
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

tokenized_dataset = dataset.map(format_chat, remove_columns=dataset.column_names)

注意这种写法是在整条序列(prompt + completion)上算 loss。如果你想要标准的 SFT 行为——只在 assistant 回复上算 loss——那就在这里把 prompt 部分的 labels 设成 -100 屏蔽掉。

2. 用一个会保留并补齐它们的 collator

默认 collator 不知道该怎么 pad 这些自定义字段,所以得自己写一个。它负责 pad input_ids、构造 attention_mask、把两个 token 类型字段按 batch 内最大长度补零,并用 -100labels 里的 padding 屏蔽掉:

from dataclasses import dataclass
import torch

@dataclass
class GemmaCollator:
    tokenizer: object

    def __call__(self, features):
        max_len = max(len(f["input_ids"]) for f in features)
        # Gemma tokenizers define a pad token; fall back to eos just in case.
        pad_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
        batch = {
            "input_ids": [],
            "attention_mask": [],
            "token_type_ids": [],
            "mm_token_type_ids": [],
            "labels": [],
        }
        for f in features:
            pad_len = max_len - len(f["input_ids"])
            batch["input_ids"].append(f["input_ids"] + [pad_id] * pad_len)
            batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * pad_len)
            batch["token_type_ids"].append([0] * max_len)
            batch["mm_token_type_ids"].append([0] * max_len)
            batch["labels"].append(
                f.get("labels", f["input_ids"]) + [-100] * pad_len
            )
        return {k: torch.tensor(v) for k, v in batch.items()}

3. 设置 remove_unused_columns=False

这一步是所有人都会忘的。trainer 默认会在 collator 运行之前,把任何模型签名"不认识"的列剔除掉,于是你辛辛苦苦加上的 mm_token_type_ids 半路就被删了,报错卷土重来。把这个行为关掉:

from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(
    output_dir="gemma4-finetune",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    learning_rate=2e-4,
    remove_unused_columns=False,                       # <-- critical: keep mm_token_type_ids
    dataset_kwargs={"skip_prepare_dataset": True},     # we pre-tokenized ourselves
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=GemmaCollator(tokenizer),
)

trainer.train()

字段在第一步被创建、在第二步被保留并补齐、在第三步被保住不被列剔除,模型就拿到了它要的 mm_token_type_ids,训练随之正常推进。

验证修复是否生效

正式开跑之前,先 dump 一个 batch 出来,确认 mm_token_type_ids 真的在、且形状全零:

collator = GemmaCollator(tokenizer)
batch = collator([tokenized_dataset[0], tokenized_dataset[1]])
print(batch.keys())
print("mm_token_type_ids" in batch)            # True
print(batch["mm_token_type_ids"].shape)        # matches input_ids
print(batch["mm_token_type_ids"].sum().item()) # 0 for text-only

只要 mm_token_type_ids 出现在 keys 里、并且求和为 0,就成了——下一次 trainer.train() 能顺利越过第一步。

常见坑与相关报错

  • 只加了 token_type_ids 这两个字段是两回事。Gemma 4 校验的明确是 mm_token_type_ids;光有不带 mm 的那个,过不了校验。
  • 忘了 remove_unused_columns=False "我明明加了字段,报错却还在"——这是最常见的原因。trainer 不声不响地先把它丢了。
  • collator 把 token 类型字段补成了错的长度。 它们必须跟补齐后的 input_ids 长度一致,而不是原始长度,否则会在 forward 更深处冒出 shape 不匹配的错误。
  • Gemma 3 上的 token_type_ids is required 更早的 Gemma 3 这一代,在 TRL 微调时会抛出同类报错的非 mm 版本(trl issue #5032)。同样用全零补齐的思路即可解决。
  • QLoRA 单独崩溃。 如果 4-bit Gemma 4 微调挂在一个跟此无关的 adapter 报错上,那是另一个已知问题(peft issue #3129),不是 collator 的锅。先解决 mm_token_type_ids 报错,再去搞量化。

想了解更全面的微调配置——LoRA 与 QLoRA 之争、Unsloth、数据准备、GGUF 导出——可以看我们的 Gemma 4 微调指南

上游会修掉这个问题吗

大概率会。transformers issue #45200 里的提案,是让模型在纯文本训练中遇到 mm_token_type_ids 缺席时默认补零,而不是抛错。在这个改动进入稳定版之前(目前在 transformers==5.5.0.dev0 源码构建上仍能复现),上面这份 collator 是稳妥可靠的绕过方案。你可以在官方 Gemma 模型文档 上跟进模型期望的输入。等修好的版本发布后,你大可以丢掉这个自定义 collator、换回标准的——不过留着它也没什么坏处。

FAQ

如果我是在图像数据上微调,还需要 mm_token_type_ids 吗? 需要,而且这时它是真的带信息的——多模态的 token 位置应该被正确标出来,而不是全填零。本文里全零的捷径只对纯文本数据成立。

我用的是 Unsloth,这套方法适用吗? 部分适用。Unsloth 把模型和 tokenizer 都包了一层,还自带了一套数据处理,所以别把这个 collator 不管三七二十一直接塞进去。先升级 Unsloth 重测一遍——它的补丁可能已经替你注入了 mm_token_type_ids。如果报错依旧,就把同一个思路套进 Unsloth 自己的流程里:在你的数据格式化(或它的 data_collator)里给 mm_token_type_ids 补零,并保持 remove_unused_columns=False,而不是整个换成这里的 collator。

我能直接降级 transformers 来躲开这个问题吗? 你可以把版本钉在 Gemma 4 支持之前的某个 release,但那样你就根本加载不了 Gemma 4 了。collator 这套修法比锁版本安全得多。

这会影响推理吗? 不会。这个校验只挂在训练的 loss 路径上。普通的生成/推理,对纯文本 prompt 来说并不要求你传 mm_token_type_ids

我用的是原生 Trainer,不是 TRL,这套还能用吗? 能。GemmaCollator 跟框架无关——把它当 data_collator 传给原生 Trainer,并在 TrainingArguments 里设 remove_unused_columns=False 即可。

下一步

gemma4 — interact

Stop reading. Start building.

~/gemma4 $ Get hands-on with the models discussed in this guide. No deployment, no friction, 100% free playground.

Launch Playground />
Gemma 4 AI

Gemma 4 AI

相关教程

Gemma 4 微调报错 mm_token_type_ids 解决方法 | 博客