
Fix Gemma 4 mm_token_type_ids Error in Fine-Tuning

Jun 2, 2026

You started a perfectly ordinary text-only fine-tune of Gemma 4, and the trainer died on the first step with this:

ValueError: `mm_token_type_ids` is required as a model input when training

No images. No audio. Just chat data. So why is Gemma 4 demanding a multimodal token field? This is a known rough edge in the day-zero Gemma 4 release (transformers issue #45200), and the fix is small once you understand it. This guide gives you a one-line workaround and a complete, copy-paste custom data collator that makes the mm_token_type_ids error disappear for good.

The Error: What You're Actually Seeing

The traceback bottoms out inside the model's forward pass:

File ".../transformers/models/gemma4/modeling_gemma4.py", line ..., in forward
    raise ValueError(
ValueError: `mm_token_type_ids` is required as a model input when training

It fires on the very first training step, regardless of your dataset, LoRA config, or GPU. People hit it with plain Trainer, with TRL's SFTTrainer, and with QLoRA setups alike. The common thread: a text-only dataset being fed to a model that was built multimodal-first.

Why This Happens (Root Cause)

Gemma 4 is multimodal by design. To tell text tokens apart from image and audio tokens, the architecture uses two extra inputs alongside input_ids:

  • token_type_ids — the standard segment field.
  • mm_token_type_ids — the multimodal token-type field that marks which positions are text vs. image vs. audio.

The problem is that the Gemma 4 model code validates the presence of mm_token_type_ids during a training forward pass that computes loss (training mode with labels present) — even when there isn't a single image in your batch. Meanwhile, the standard tokenizer and the default data collators never produce that field. For a text-only run every value would simply be 0, but because the field is missing entirely rather than zero-filled, the model raises instead of assuming a sensible default. (As of the day-zero transformers==5.5.0.dev0 source build, the relevant raise lives in the training branch of modeling_gemma4.py's forward; exact line numbers move between commits.)

So the bug is really a mismatch: the model requires mm_token_type_ids, your tokenizer never creates it, and the trainer throws it away even if you do. Fix all three and the error is gone.

The Quick Fix (One Line)

If you control the loop and just want to unblock training right now, inject a zero tensor before the forward call (make sure torch is imported):

import torch

inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])

For text-only data, every token is a text token, so an all-zeros tensor shaped like input_ids is exactly correct. This is enough to get a manual training loop moving. But with Trainer / SFTTrainer you don't own the forward call, so you need the field baked into the batch itself — which is what the collator below does.

The Complete Fix: A Custom Gemma 4 Data Collator

There are three pieces, and you need all three. Skip any one and the mm_token_type_ids error comes right back.

0. Minimal setup

So the snippets below are runnable end to end, here's the scaffolding they assume — swap in your own model and dataset:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "google/gemma-4-4b-it"  # any Gemma 4 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)  # `dtype` replaces the deprecated `torch_dtype` kwarg

# This guide assumes each example has a `messages` field (chat format):
#   {"messages": [{"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}]}
# If your data is instruction/output style, convert it to messages first.
dataset = load_dataset("your/dataset", split="train")
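
If your rows are instruction/output style, a minimal conversion sketch follows. The column names instruction, input, and output are assumptions (they match common Alpaca-style datasets), and to_messages is a hypothetical helper name. Rename the keys to match your data.

```python
def to_messages(example):
    """Convert an Alpaca-style row into the chat `messages` format.

    Assumes hypothetical columns: `instruction`, optional `input`, and `output`.
    """
    prompt = example["instruction"]
    extra = example.get("input") or ""
    if extra:
        # Fold the optional context into the user turn.
        prompt = f"{prompt}\n\n{extra}"
    return {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": example["output"]},
        ]
    }
```

Apply it with dataset = dataset.map(to_messages, remove_columns=dataset.column_names) before the tokenization step, so only the new messages column remains.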

1. Add both fields during tokenization

When you tokenize, attach token_type_ids and mm_token_type_ids as all-zeros lists the same length as input_ids:

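def format_chat(example):
    text = tokenizer.apply_chat_template(
        example["messages"], tokenize=False, add_generation_prompt=False
    )
    # Gemma chat templates typically begin with <bos>, so don't let the
    # tokenizer prepend a second one.
    tokenized = tokenizer(
        text, truncation=True, max_length=4096, add_special_tokens=False
    )
    n = len(tokenized["input_ids"])
    # Text-only data: every position is a text token, so both fields are all zeros.
    tokenized["token_type_ids"] = [0] * n
    tokenized["mm_token_type_ids"] = [0] * n
    # Loss over the full sequence (see the note on prompt masking below).
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized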

tokenized_dataset = dataset.map(format_chat, remove_columns=dataset.column_names)

Note this computes loss over the full sequence (prompt + completion). If you want standard SFT behavior — loss on the assistant reply only — mask the prompt tokens by setting their labels to -100 here.
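
A minimal sketch of that masking follows. It assumes the last message is the assistant reply, and that rendering the earlier turns with add_generation_prompt=True yields a token-level prefix of the full rendering. That holds for typical chat templates, but check it on your checkpoint. mask_prompt_labels and format_chat_masked are hypothetical helper names, not library functions.

```python
def mask_prompt_labels(input_ids, prompt_len):
    """Build labels that ignore the first `prompt_len` tokens (label -100)."""
    # The prompt can be longer than a truncated sequence; never overrun it.
    prompt_len = min(prompt_len, len(input_ids))
    return [-100] * prompt_len + list(input_ids[prompt_len:])


def format_chat_masked(example, tokenizer, max_length=4096):
    """Variant of format_chat that puts loss on the final assistant turn only."""
    messages = example["messages"]
    full_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    # Every turn before the final reply, ending with the generation prompt
    # that opens the assistant turn.
    prompt_text = tokenizer.apply_chat_template(
        messages[:-1], tokenize=False, add_generation_prompt=True
    )
    full = tokenizer(
        full_text, truncation=True, max_length=max_length, add_special_tokens=False
    )
    prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    n = len(full["input_ids"])
    return {
        "input_ids": full["input_ids"],
        "attention_mask": full["attention_mask"],
        "token_type_ids": [0] * n,
        "mm_token_type_ids": [0] * n,
        "labels": mask_prompt_labels(full["input_ids"], len(prompt_ids)),
    }
```

Use it in place of format_chat via dataset.map(format_chat_masked, fn_kwargs={"tokenizer": tokenizer}, remove_columns=dataset.column_names). The collator below keeps these labels and only appends -100 for padding.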

2. Use a collator that preserves and pads them

A default collator won't know how to pad these custom fields, so define your own. It pads input_ids, builds the attention_mask, zero-fills both token-type fields to the batch's max length, and masks padding in labels with -100:

from dataclasses import dataclass
import torch

@dataclass
class GemmaCollator:
    tokenizer: object

    def __call__(self, features):
        max_len = max(len(f["input_ids"]) for f in features)
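        # Gemma tokenizers define a pad token, often with id 0. Avoid `or` here:
        # it would treat id 0 as missing. Fall back to eos only if pad is unset.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id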
        batch = {
            "input_ids": [],
            "attention_mask": [],
            "token_type_ids": [],
            "mm_token_type_ids": [],
            "labels": [],
        }
        for f in features:
            pad_len = max_len - len(f["input_ids"])
            batch["input_ids"].append(f["input_ids"] + [pad_id] * pad_len)
            batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * pad_len)
            batch["token_type_ids"].append([0] * max_len)
            batch["mm_token_type_ids"].append([0] * max_len)
            batch["labels"].append(
                f.get("labels", f["input_ids"]) + [-100] * pad_len
            )
        return {k: torch.tensor(v) for k, v in batch.items()}

3. Set remove_unused_columns=False

This is the step everyone forgets. By default the trainer strips any column the model's signature doesn't "recognize" before the collator runs, so your carefully added mm_token_type_ids gets deleted in transit and the error returns. Turn that off:

from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(
    output_dir="gemma4-finetune",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    learning_rate=2e-4,
    remove_unused_columns=False,                       # <-- critical: keep mm_token_type_ids
    dataset_kwargs={"skip_prepare_dataset": True},     # we pre-tokenized ourselves
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=GemmaCollator(tokenizer),
)

trainer.train()

With the field created in step 1, preserved and padded in step 2, and protected from column-stripping in step 3, the model gets its mm_token_type_ids and training proceeds normally.
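
To see why step 3 matters, here is a simplified sketch of signature-based column filtering, the idea behind remove_unused_columns=True. It is not the trainer's actual code. keep_signature_columns is a hypothetical helper, and forward is a hypothetical stand-in for whatever signature the trainer ends up inspecting, written so that it does not name mm_token_type_ids.

```python
import inspect


def keep_signature_columns(columns, forward_fn, extra=("label", "label_ids", "labels")):
    """Keep only columns that are parameters of forward_fn (plus label-like names).

    Simplified illustration of signature-based filtering; setting
    remove_unused_columns=False skips this filtering entirely.
    """
    allowed = set(inspect.signature(forward_fn).parameters) | set(extra)
    return [c for c in columns if c in allowed]


# Hypothetical forward whose signature does not list mm_token_type_ids.
def forward(input_ids=None, attention_mask=None, labels=None, **kwargs):
    ...
```

With columns ["input_ids", "attention_mask", "mm_token_type_ids", "labels"], this keeps everything except mm_token_type_ids, which is exactly the "I added the field but the error is still there" symptom.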

Verifying the Fix

Before launching a full run, dump one batch and confirm that mm_token_type_ids is present, has the same shape as input_ids, and contains only zeros:

collator = GemmaCollator(tokenizer)
batch = collator([tokenized_dataset[0], tokenized_dataset[1]])
print(batch.keys())
print("mm_token_type_ids" in batch)            # True
print(batch["mm_token_type_ids"].shape)        # matches input_ids
print(batch["mm_token_type_ids"].sum().item()) # 0 for text-only

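If mm_token_type_ids is in the keys, its shape matches input_ids, and its sum is 0, the collator is doing its job and trainer.train() should get past step one. Keep in mind that calling the collator directly skips the trainer's column filtering, so this check does not exercise step 3. To test the full path, inspect next(iter(trainer.get_train_dataloader())) after building the trainer.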
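
If you prefer a single pass/fail check over reading printouts, here is a minimal sketch. check_text_only_batch is a hypothetical helper. It accepts torch tensors (via .tolist()) or nested Python lists, and it assumes text-only data, where any non-zero mm_token_type_ids value is unexpected.

```python
def _rows(x):
    # Accept torch tensors (via .tolist()) or nested Python lists.
    return x.tolist() if hasattr(x, "tolist") else x


def check_text_only_batch(batch):
    """Return a list of problems in a collated text-only batch (empty list = OK)."""
    if "mm_token_type_ids" not in batch:
        return ["mm_token_type_ids is missing"]
    problems = []
    ids = _rows(batch["input_ids"])
    mm = _rows(batch["mm_token_type_ids"])
    # Shapes must match the *padded* input_ids, row by row.
    if len(mm) != len(ids) or any(len(a) != len(b) for a, b in zip(mm, ids)):
        problems.append("mm_token_type_ids shape does not match input_ids")
    if any(v != 0 for row in mm for v in row):
        problems.append("mm_token_type_ids has non-zero values")
    # Padding positions (attention_mask == 0) must be excluded from the loss.
    if "labels" in batch and "attention_mask" in batch:
        for lab_row, m_row in zip(_rows(batch["labels"]), _rows(batch["attention_mask"])):
            if any(m == 0 and lab != -100 for lab, m in zip(lab_row, m_row)):
                problems.append("padding positions in labels are not masked with -100")
                break
    return problems
```

Running problems = check_text_only_batch(batch) on the batch from the snippet above should give an empty list.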

Common Pitfalls

  • Only adding token_type_ids. The two fields are different. Gemma 4 specifically checks mm_token_type_ids; the non-mm field alone won't satisfy the validation.
  • Forgetting remove_unused_columns=False. This is the most common cause of "I added the field but the error is still there." The trainer silently drops it first.
  • Letting the collator pad token-type fields with the wrong length. They must match the padded input_ids length, not the original, or you'll get a shape-mismatch deeper in the forward pass.
  • token_type_ids is required on Gemma 3. The earlier Gemma 3 generation throws the non-mm variant of this same error during TRL fine-tuning (trl issue #5032). The same zero-fill strategy applies.
  • QLoRA crashing separately. If 4-bit Gemma 4 fine-tuning fails with an unrelated adapter error, that's a different known issue (peft issue #3129), not the collator. Fix the mm_token_type_ids error first, then tackle quantization.
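
Because Gemma 3 and Gemma 4 differ only in which field they require, one zero-fill step can serve both. The following is a minimal sketch under that assumption. zero_fill_token_types is a hypothetical helper, and it is only valid for text-only examples.

```python
def zero_fill_token_types(example, fields=("token_type_ids", "mm_token_type_ids")):
    """Return all-zero token-type lists, one per name in `fields`, sized to input_ids.

    Gemma 4: keep both fields. Gemma 3: pass fields=("token_type_ids",).
    Text-only data only: every position is treated as a text token.
    """
    n = len(example["input_ids"])
    return {name: [0] * n for name in fields}
```

Dataset.map merges the returned dict into each row, so tokenized_dataset.map(zero_fill_token_types) adds both fields, and tokenized_dataset.map(zero_fill_token_types, fn_kwargs={"fields": ("token_type_ids",)}) covers the Gemma 3 case.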

For broader fine-tuning setup — LoRA vs. QLoRA, Unsloth, dataset prep, and GGUF export — see our Gemma 4 fine-tuning guide.

Will This Be Fixed Upstream?

Probably. The proposal in transformers issue #45200 is for the model to default mm_token_type_ids to zeros when it's absent during text-only training, rather than raising. Until that lands in a stable release (the error currently reproduces on the transformers==5.5.0.dev0 source build), the collator above is the reliable workaround. You can track the model's expected inputs on the official Gemma model docs. Once a fixed version ships, you can drop the custom collator and go back to a standard one — but keeping it does no harm.

FAQ

Do I need mm_token_type_ids if I'm fine-tuning on images? Yes, and then it actually carries signal — the multimodal token positions should be marked, not all-zero. The all-zeros shortcut here is only valid for text-only data.

I'm using Unsloth. Does this apply? Partly. Unsloth wraps the model and tokenizer and ships its own data handling, so don't drop this collator in blindly. First update Unsloth and re-test — its patches may already inject mm_token_type_ids for you. If the error persists, apply the same idea inside Unsloth's flow: zero-fill mm_token_type_ids in your dataset formatting (or its data_collator) and keep remove_unused_columns=False, rather than swapping in this collator wholesale.
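
One way to zero-fill inside an existing data_collator, without replacing it, is to wrap it. This is a minimal sketch. ZeroFillMMCollator is a hypothetical name, it has not been tested against Unsloth's own patches, and it only suits text-only batches.

```python
class ZeroFillMMCollator:
    """Wrap an existing collator; add all-zero mm_token_type_ids if it is absent."""

    def __init__(self, base_collator):
        self.base_collator = base_collator

    def __call__(self, features):
        batch = self.base_collator(features)
        if "mm_token_type_ids" not in batch:
            ids = batch["input_ids"]
            if hasattr(ids, "new_zeros"):
                # torch.Tensor: zeros with the same shape, dtype, and device.
                batch["mm_token_type_ids"] = ids.new_zeros(ids.shape)
            else:
                # Nested Python lists from a custom collator.
                batch["mm_token_type_ids"] = [[0] * len(row) for row in ids]
        return batch
```

Pass data_collator=ZeroFillMMCollator(existing_collator) and still keep remove_unused_columns=False. The wrapper leaves the field alone if the base collator (or a future library fix) already provides it.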

Can I just downgrade transformers to avoid this? You can pin to a release before Gemma 4 support, but then you can't load Gemma 4 at all. The collator fix is safer than version-pinning.

Does this affect inference too? No. The check is gated on the training loss path. Plain generation/inference doesn't require you to pass mm_token_type_ids for text-only prompts.

I'm using plain Trainer, not TRL. Does this still work? Yes. The GemmaCollator is framework-agnostic — pass it as data_collator to a vanilla Trainer and set remove_unused_columns=False in your TrainingArguments.
