Looking for guidance: trying to create a model with TrOCR's encoder + Google's mT5 multilingual decoder, but the model fails to overfit on a single data sample

Hi everyone,

I am working on building a proof of concept for an OCR system that can recognize both handwritten and printed Hindi (Devanagari) text in complex documents. I’m trying to build on top of TrOCR (microsoft/trocr-base-handwritten) since it already has a strong vision encoder trained for handwriting recognition.

The core problem I’m running into is on the decoder/tokenizer side — TrOCR’s default decoder and tokenizer are trained for English only, and I need Hindi output.

What I’ve tried so far:

I replaced TrOCR’s decoder with google/mt5-small, which natively supports Hindi tokenization. The hidden sizes matched, so I expected this to work.

However, the model failed to overfit even on a single data point. The loss comes down but hovers around 2-3 at the end, and the characters keep repeating instead of forming a meaningful word or sentence. I have tried changing the learning rate and introducing a repetition penalty, but overfitting just doesn’t happen.

I need guidance: is there any other tokenizer out there that works well with TrOCR’s encoder, or can you help me improve the current setup (TrOCR’s encoder + mT5 decoder)?


Hmm…


This is fixable enough to keep exploring, but the main problem is probably not the tokenizer itself. The bigger problem is that your current experiment combines a TrOCR encoder that was fine-tuned for English single-line handwriting, a custom mT5-as-decoder-only wiring path, and a difficult Hindi OCR target. That is a fragile combination. Hugging Face’s encoder-decoder docs explicitly warn that when you combine a pretrained encoder and a different decoder, the cross-attention layers may be randomly initialized and must be learned during fine-tuning. They also show that the supported decoder path is usually a decoder model configured for cross-attention, not a full seq2seq model hacked into decoder-only use. (Hugging Face)

The most important conclusion

I do not think your experiment proves that “TrOCR encoder + Hindi-capable decoder cannot work.” I think it proves that your current wiring and training regime are too unstable to make that judgment. The fact that loss drops at all means the image path, label path, and cross-modal connection are at least partially alive. The repeated characters point more toward autoregressive instability than “complete failure.” Repetition is also a known failure mode in TrOCR-style generation, especially when decoder setup or generation config is off. (Hugging Face)

What is going wrong in your current Colab

From the code you shared, these are the biggest issues.

1. Your notebook says mt5-small, but the code loads mt5-base

That is not a cosmetic detail. mt5-base is materially larger and harder to stabilize than mt5-small. For a one-sample overfit test, you want the smallest model that can still express the task. Using a larger multilingual decoder makes the bridge-learning problem harder, not easier.

2. You are starting from trocr-base-handwritten, which is already specialized

The public model card says microsoft/trocr-base-handwritten is a TrOCR model fine-tuned on the IAM dataset. The updated README also says it works best on single-line handwritten English text and is not optimized for printed text or multi-line inputs. For a language swap, trocr-base-stage1 or trocr-small-stage1 is usually a cleaner starting point because those are the pre-trained only checkpoints rather than the already English-finetuned handwritten checkpoint. (Hugging Face)

3. The mT5 wiring path is custom, and that matters

You are not using the standard VisionEncoderDecoderModel.from_encoder_decoder_pretrained(...) path. Instead, you replace the mT5 encoder with a dummy module and feed encoder_outputs directly into MT5ForConditionalGeneration. That can work, but public Hugging Face issue history shows that using T5 or ByT5 as decoder-only for OCR is still a custom workaround path, not the most standard one. There is a dedicated issue where a user had to create a T5DecoderOnlyForCausalLM subclass for this exact reason. (GitHub)

4. Your one-sample overfit test is not a clean overfit test

In your code, the one-sample test uses:

  • full trainable model
  • AdamW(lr=1e-3)
  • only 150 steps
  • beam search during evaluation

That is too aggressive and too noisy. T5-family docs say that with AdamW, values around 1e-4 to 3e-4 typically work well, and they note that T5 was pretrained with Adafactor. Also, for T5 and mT5, the correct decoder start behavior is to use pad_token_id. (Hugging Face)

So your current test is mixing three confounders:

  • LR is likely too high for this hybrid.
  • The decoder is larger than needed.
  • Beam search is a poor judge of early training quality.

5. Decoder masking is too implicit

The official encoder-decoder implementation uses shift-right logic to build decoder inputs from labels. There is also a recent Transformers issue pointing out that in VisionEncoderDecoderModel, users observed that decoder_attention_mask was not always created the way they expected when labels were shifted into decoder inputs. In a custom hybrid like yours, I would not leave this implicit. I would create decoder_input_ids and decoder_attention_mask explicitly. (GitHub)

My recommendation for your current setup

Keep the overall idea for now, but simplify the experiment hard.

Recommended first rebuild

Use:

  • microsoft/trocr-small-stage1 or microsoft/trocr-base-stage1
  • google/mt5-small
  • explicit decoder inputs and decoder attention mask
  • frozen encoder at first
  • greedy decoding
  • lower LR
  • longer one-sample training

Why this version first:

  • stage1 is a cleaner visual warm start than the English IAM handwritten checkpoint for a decoder swap. (Hugging Face)
  • mt5-small is easier to stabilize than mt5-base.
  • mT5 already supports Hindi tokenization and uses pad_token_id as the decoder start token, so the tokenizer is not the core blocker. (Hugging Face)

Concrete changes I would make

A. Change the checkpoints

Use:

trocr = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-stage1")
mt5_model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
image_processor = ViTImageProcessor.from_pretrained("microsoft/trocr-small-stage1")

This removes two sources of instability at once: an over-specialized English handwritten checkpoint and an unnecessarily large decoder. The stage1 models are the pre-trained-only TrOCR checkpoints. (Hugging Face)

B. Set both model config and generation config

Do this:

model.mt5.config.decoder_start_token_id = tokenizer.pad_token_id
model.mt5.config.pad_token_id = tokenizer.pad_token_id
model.mt5.config.eos_token_id = tokenizer.eos_token_id
model.mt5.config.use_cache = False

model.mt5.generation_config.decoder_start_token_id = tokenizer.pad_token_id
model.mt5.generation_config.pad_token_id = tokenizer.pad_token_id
model.mt5.generation_config.eos_token_id = tokenizer.eos_token_id

mT5 uses pad_token_id to start decoder generation. That part of your code is conceptually right, but I would set generation_config too. (Hugging Face)

C. Make decoder inputs explicit

Inside forward, do not rely only on labels=... to do everything.

from transformers.modeling_outputs import BaseModelOutput

def forward(self, pixel_values, labels=None):
    hidden = self._encode(pixel_values)

    decoder_input_ids = None
    decoder_attention_mask = None

    if labels is not None:
        # _shift_right prepends the decoder start token (pad for mT5)
        # and maps -100 label positions back to the pad id.
        decoder_input_ids = self.mt5._shift_right(labels)
        decoder_attention_mask = (decoder_input_ids != self.mt5.config.pad_token_id).long()
        # The start token IS the pad token for mT5, so keep position 0 visible.
        decoder_attention_mask[:, 0] = 1

    return self.mt5(
        encoder_outputs=BaseModelOutput(last_hidden_state=hidden),
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        labels=labels,
        use_cache=False,
    )

This makes the training path less ambiguous, and it lines up with how encoder-decoder training is supposed to work conceptually: labels are shifted right into decoder inputs. (GitHub)
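As a sanity check, the shift-right behavior can be sketched in plain Python. This is a toy illustration, not the library implementation; `shift_right` here is a hypothetical helper mirroring what T5-family `_shift_right` does on one sequence:

```python
def shift_right(labels, pad_id, start_id):
    """Toy version of T5-style shift-right for one label sequence.

    Decoder inputs start with the decoder start token (pad_token_id
    for mT5), labels move one position right, and -100 positions
    (ignored in the loss) map back to the pad id.
    """
    shifted = [start_id] + labels[:-1]
    return [pad_id if t == -100 else t for t in shifted]


# Labels: three real token ids, one ignored padding position.
print(shift_right([5, 6, 1, -100], pad_id=0, start_id=0))  # [0, 5, 6, 1]
```

The point to internalize: the decoder never sees the token it is supposed to predict at the same position, which is why building these inputs explicitly makes the training path easier to audit.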

D. Freeze the encoder first

At the beginning, the fragile part is the bridge, not the vision backbone. So start with:

for p in model.encoder.parameters():
    p.requires_grad = False

for name, p in model.mt5.named_parameters():
    p.requires_grad = (
        ("EncDecAttention" in name) or
        ("lm_head" in name) or
        ("shared" in name)
    )

if model.enc_to_dec_proj is not None:
    for p in model.enc_to_dec_proj.parameters():
        p.requires_grad = True

This follows directly from the encoder-decoder warm-start logic: the cross-attention bridge is new and needs to be learned carefully. (Hugging Face)

E. Fix the one-sample overfit protocol

For the one-sample proof, use:

  • lr=1e-4
  • weight_decay=0.0
  • no dropout
  • greedy decode
  • 500 to 1000 steps
  • teacher-forced token accuracy

Example:

optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-4,
    weight_decay=0.0,
)

for m in model.modules():
    if isinstance(m, nn.Dropout):
        m.p = 0.0

for step in range(1, 1001):
    outputs = model(pixel_values=pv, labels=lb)
    loss = outputs.loss

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    if step % 20 == 0:
        model.eval()
        with torch.no_grad():
            tf_outputs = model(pixel_values=pv, labels=lb)
            tf_pred = tf_outputs.logits.argmax(-1)
            mask = lb != -100
            token_acc = (tf_pred[mask] == lb[mask]).float().mean().item()

            gen_ids = model.generate(
                pixel_values=pv,
                max_new_tokens=int(mask.sum().item()) + 4,
                num_beams=1,
                do_sample=False,
            )
            pred = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

        print(step, loss.item(), token_acc, pred)
        model.train()

The T5 docs support the lower LR recommendation. Greedy decoding removes beam-search noise from the diagnosis. (Hugging Face)

What success should look like

Do not judge by loss alone.

For a one-sample test, success is:

  1. teacher-forced token accuracy approaches 1.0
  2. greedy decoded text becomes an exact match
  3. it stays stable for multiple checks

If loss goes down but token accuracy stays mediocre, the bridge is not learning properly. If token accuracy gets high but free decoding still loops, the model is learning under teacher forcing but autoregressive generation is unstable.

About the tokenizer question

The practical answer is:

Do not think “which tokenizer works with TrOCR encoder?”
Think “which decoder family works best with the TrOCR encoder?”

The tokenizer comes with the decoder family.

Best current options, in order

1. XLM-R decoder

This is the cleanest TrOCR-style multilingual path inside Transformers. Hugging Face’s public decoder-replacement guidance explicitly shows replacing TrOCR’s decoder with RobertaForCausalLM.from_pretrained("xlm-roberta-base", is_decoder=True, add_cross_attention=True). That is the most standard multilingual replacement route. (Hugging Face Forums)

Why it is attractive:

  • closer to the standard VisionEncoderDecoderModel recipe
  • easier than custom T5-decoder-only plumbing
  • multilingual tokenizer already available

2. IndicBART

If your real target is Hindi and perhaps other Indian languages, this is one of the strongest alternatives. IndicBART is a multilingual seq2seq model focused on 11 Indian languages plus English. There is also a public trocr-indic model built around IndicBART, and it explicitly supports Hindi, though it notes a Devanagari-script limitation in the released setup. (Hugging Face)

Why it is attractive:

  • more language-focused for Indic text than mT5
  • seq2seq architecture fits OCR-style generation naturally
  • smaller and more targeted than mt5-base

3. ByT5

ByT5 is tokenizer-free and works directly on UTF-8 bytes. The model docs say it is more robust to noise and can process any language without a separate tokenizer vocabulary. That is interesting for OCR because OCR errors often look like noisy character sequences. (Hugging Face)

Why it is attractive:

  • no tokenizer coverage problem
  • strong fit for noisy OCR text

Why I would not pick it first:

  • sequence lengths are longer
  • it still lives in the T5 family, so the decoder-only integration pain remains

4. Stay with mT5

This is still viable. mT5 covers 101 languages and already supports Hindi tokenization. I would keep it only after fixing the wiring and training regime first. (Hugging Face)

My recommendation on alternatives

If your goal is the least risky next step, I would rank them like this:

  1. TrOCR encoder + XLM-R decoder
  2. TrOCR encoder + IndicBART
  3. TrOCR encoder + repaired mT5-small setup
  4. ByT5 experiment only after the above

That ranking is based on current Hugging Face implementation guidance and public issue history. The T5 decoder-only route is the least standard of the four. (Hugging Face Forums)

For your final end goal: complex documents

This part matters a lot.

Your target is not just Hindi recognition. It is handwritten + printed Hindi in complex documents. The public TrOCR model card and discussion history strongly suggest that the handwritten checkpoint is best on single text-line inputs, and users doing full-page OCR typically detect or crop regions first, then run TrOCR on those crops. (Hugging Face)

So I would not design the final system as “single recognizer eats full page.” I would design it as:

  1. text-region detection
  2. line grouping or crop extraction
  3. Hindi recognizer on each crop
  4. merge results

For more document-native approaches, Donut is worth tracking because it is an OCR-free document model, but that is a different design choice from a recognizer-focused OCR pipeline. (Hugging Face)

My blunt recommendation

For your case, I would do this next:

Path A. Repair your current experiment

  • switch to trocr-small-stage1
  • switch to mt5-small
  • explicit decoder inputs and mask
  • freeze encoder
  • lr=1e-4
  • greedy decode
  • 1000-step one-sample overfit

Path B. If that still fails

Stop tuning repetition penalties. Move to:

  • VisionEncoderDecoderModel
  • xlm-roberta-base as decoder using RobertaForCausalLM
  • matching tokenizer for labels

That is the cleanest multilingual TrOCR path publicly documented by Hugging Face. (Hugging Face Forums)

Path C. If Hindi quality is still weak

Try IndicBART next, because it is actually designed around Indic languages rather than broad multilingual coverage. (Hugging Face)

Final answer

Your current result does not tell me “the idea is wrong.” It tells me:

  • you have a partially working pipeline,
  • your current overfit test is too unstable to trust,
  • the tokenizer is probably not the main blocker,
  • the biggest issue is the bridge + training regime,
  • and for your final use case, you should treat recognition and document layout as separate problems. (Hugging Face)

The fastest high-value move is to rebuild the one-sample test in the smaller, cleaner form above. If that passes, then your architecture is viable. If it still does not pass, switch decoder family before spending more time on hyperparameter tweaking.

Thanks a lot! @John6666

The corrections and ideas you suggested worked on the overfit test, and the model consistently started producing correct predictions on the single data sample I was trying to train it on. With your help, I don’t need to look for a new tokenizer, since I understood the problem was in my wiring of the encoder and decoder. Now I am trying to make improvements within this setup to achieve some level of results.

I want to proceed with the next steps I had in mind. I even tried, but I am a bit stuck and confused.

I wanted to proceed with the proof of concept using a dataset of just 210 samples of printed and handwritten Hindi line crops, create a training class to train the model on it, and see if it is able to learn something meaningful.

160-165 samples being the training set and the rest being the validation set. I just want to do something here that more or less proves the pipeline is ready for gradual increases to larger datasets in the thousands.

I did try this today by creating a training class and training all parameters, which I think is around 400 million.

The results were not good. The training loss was consistently going down, but I was also hoping for the character error rate (CER) to go below 1.0 after 50 epochs. That didn’t happen: CER started at around 1.2 or 1.3, went down to 0.98 at its best, and after 50 epochs ended up around 1.5.

I looked for ways to improve this through Claude; I am depending on Claude a lot due to my own knowledge gaps, but it is somewhat making things even more confusing.

It suggested freezing the decoder and training the encoder only. This reduced trainable parameters to around 262 million out of 431 million total, but that didn’t work either. CER remained above 1 after 50 epochs.

I want to ask:

  1. Is my approach of training a model on just 210 data samples for a proof of concept correct, or should I use a larger dataset with at least 1000 Hindi line crops?

  2. If yes, what scores should I aim for in the CER and WER metrics? What would be acceptable CER and WER in this training to conclude that my pipeline is ready for a larger dataset?

  3. One last question: I am not able to understand which layers to freeze or how exactly to train the entire thing to see improvement in results. Could you suggest a correct design for the training class?


Well, I’m not sure if the following patch will work as-is, but for now:


For your case, the right conclusion is:

210 samples is a valid proof-of-concept dataset, but only if you treat it as a pipeline test, not a quality benchmark.
Your current notebook is close enough to keep. The main problem is no longer encoder-decoder wiring. The main problem is the training control logic and the fact that 160 training lines is far too little data for full fine-tuning of hundreds of millions of parameters. In warm-started encoder-decoder hybrids, the cross-attention bridge can be randomly initialized and needs downstream fine-tuning, and in mT5 the decoder starts from pad_token_id, which matches the direction you already fixed in your notebook. (Hugging Face)

My direct answer to your three questions

1. Is 210 samples okay?

Yes, for a proof that the pipeline is capable of learning.
No, for judging final model quality or choosing the best architecture.

A 210-line dataset can answer a narrow but useful question:

“Can this TrOCR-encoder plus mT5-decoder pipeline learn stable Hindi line recognition on real data without collapsing?”

That is a good proof-of-concept question. It is much narrower than “Is this good OCR?” and that is the correct way to use 210 samples. Public TrOCR documentation says the raw model is intended for single text-line images, so your use of line crops is aligned with the model family’s intended use. At the same time, public Indic OCR resources are much larger. IIIT-INDIC-HW-WORDS reports 872K handwritten instances across 8 Indic scripts, and recent low-resource Indic OCR work explicitly leans on synthetic data or parameter-efficient adaptation because small real datasets are not enough by themselves. (Hugging Face)

So the right framing is:

  • 210 lines = enough to validate that the training recipe is sane.
  • 1000+ lines = much better for deciding whether the model scales.
  • many thousands = where quality conclusions start to matter.

2. What CER and WER should you aim for?

For your current dataset size, use CER as the primary metric and WER as a secondary metric.

That is because CER is character-level edit distance and is smoother on tiny validation sets, while WER is harsher and noisier on short lines. Hugging Face’s CER implementation defines CER as character-level edit distance normalized by reference length, and WER is the word-level analogue based on substitutions, deletions, and insertions. Both are “lower is better,” and both can exceed 1.0 when insertions are large. So “CER below 1.0” is not a meaningful success threshold. A CER around 0.98 is still very poor. (GitHub)
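To make the “CER can exceed 1.0” point concrete, here is a minimal edit-distance CER in plain Python. This is a sketch for intuition only; in practice use the `evaluate`/`jiwer` implementations:

```python
def edit_distance(ref: str, hyp: str) -> int:
    # Standard Levenshtein distance with a rolling one-row DP table.
    m, n = len(ref), len(hyp)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            cur = dp[j]
            dp[j] = min(
                dp[j] + 1,                        # deletion
                dp[j - 1] + 1,                    # insertion
                prev + (ref[i - 1] != hyp[j - 1]) # substitution / match
            )
            prev = cur
    return dp[n]


def cer(ref: str, hyp: str) -> float:
    # Edit distance normalized by reference length, so insertion-heavy
    # outputs (e.g. repetition loops) can push CER well above 1.0.
    return edit_distance(ref, hyp) / max(len(ref), 1)


print(cer("अब", "अबबबबबब"))  # 2.5: five insertions against a 2-char reference
```

This is exactly why a repetition-collapsed model can show CER 1.5: the decoder is emitting more characters than the reference contains.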

These are the thresholds I would use for your proof-of-concept, and these are engineering thresholds, not official published cutoffs:

Minimum green light

  • validation CER < 0.50
  • validation WER < 0.80
  • predictions are readable
  • no repeated-character collapse
  • both printed and handwritten lines improve

Better green light

  • validation CER around 0.25 to 0.40
  • validation WER around 0.45 to 0.70

Strong green light

  • validation CER < 0.20 to 0.25
  • validation WER < 0.40 to 0.50

For 210 mixed-domain samples, I would call the pipeline “ready to scale” once you can reliably beat CER 0.5 and show visibly readable predictions on both printed and handwritten validation lines. The important word is reliably. One lucky split is not enough.

3. Which layers should you freeze, and how should the training class be designed?

For your exact architecture, freezing the decoder and training only the encoder is the wrong direction.

The decoder side contains:

  • the Hindi text generation behavior,
  • the new cross-attention bridge that connects image features to text generation,
  • and the autoregressive dynamics that are currently causing metric instability.

Hugging Face’s encoder-decoder docs explicitly say cross-attention may be randomly initialized in these hybrids and must be fine-tuned downstream. That means the most important adaptation is usually decoder-side cross-attention and output-side behavior, not encoder-only retraining. (Hugging Face)

So for your notebook, the correct first training stage is:

  • freeze the entire visual encoder
  • train enc_to_dec_proj if it exists
  • train decoder cross-attention
  • train lm_head
  • train shared embeddings
  • optionally train decoder layer norms

That is the right first-stage design for your current notebook.

My thoughts after checking your training cells

Your notebook is now in a much better place than before. These are the good parts:

  • you moved to trocr-small-stage1 plus mt5-small
  • the custom wrapper is now sane enough to test
  • the overfit test already proved the bridge can learn
  • the real trainer freezes the encoder and trains decoder-side bridge/output parameters

That is the right direction.

The weak part is the trainer design, not the architecture.

The biggest trainer problem

Right now, in your notebook:

  • validation loss is computed every epoch,
  • CER/WER are only computed every 5 epochs,
  • the scheduler follows validation loss,
  • but the best-checkpoint logic follows CER.

That is a mismatch.

For seq2seq OCR, teacher-forced loss and free-generation CER can move in different directions. The model can keep lowering validation loss while actual OCR output gets worse. That is especially common in tiny-data autoregressive setups. So a scheduler driven by val_loss and checkpointing driven by CER can easily train past the best OCR model.

That is the main reason I do not trust the 50-epoch result as a true judgment of the architecture.

What I think your current CER curve actually means

You said:

  • CER started around 1.2 to 1.3
  • improved to about 0.98
  • then worsened to about 1.5

That pattern usually means:

  • the model is learning something at first,
  • then overfitting or decoding drift sets in,
  • and later epochs add insertions or repetitive garbage.

Since CER is edit-distance based, late insertions can easily push it above 1.0. So this is not “almost good but not quite.” It is “still poor overall, with a brief early improvement that was not preserved.” (GitHub)

The good news is that this pattern usually points to training-control problems, not “your architecture cannot work.”

The exact changes I would make now

1. Compute CER and WER every epoch

On a validation set of about 45 to 50 lines, the extra compute is small. The benefit is large.

In your trainer, set:

generate_every_n_epochs = 1

That alone makes your best-epoch detection much more trustworthy.

2. Use CER as the one metric that controls training

Use validation CER for all three:

  • scheduler stepping
  • best-checkpoint saving
  • early stopping

Do not split those across val_loss and CER.

For this dataset size and task, CER is the best control metric. WER is still useful, but mainly as a reporting metric.

3. Add early stopping

Do not run 50 fixed epochs on 160 training lines.

Use:

  • num_epochs = 20 or 25
  • patience = 5
  • min_delta = 0.005 on CER

The best checkpoint will probably appear earlier than epoch 50. Right now your notebook is not designed to stop there.
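A minimal sketch of that stopping rule, assuming validation CER is computed every epoch. `EarlyStopper` is an illustrative helper, not something from the notebook:

```python
class EarlyStopper:
    """Stop when validation CER has not improved by min_delta for `patience` epochs."""

    def __init__(self, patience: int = 5, min_delta: float = 0.005):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_cer: float) -> bool:
        # Returns True when training should stop.
        if val_cer < self.best - self.min_delta:
            self.best = val_cer
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Call `stopper.step(val_cer)` once per epoch after evaluation; only improvements larger than `min_delta` reset the patience counter, so metric noise on a 45-line validation set does not keep training alive forever.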

4. Use parameter groups, not one flat AdamW group

In your notebook, all trainable parameters currently use one LR and one weight decay. That is too blunt.

Hugging Face’s training docs note that biases and LayerNorm parameters are usually excluded from weight decay. I would go one step further and also give the bridge and cross-attention a slightly higher learning rate than the rest of the decoder-side trainable weights. (Hugging Face)

A good split is:

  • bridge and cross-attention: LR 2e-4, weight decay 0.01
  • lm_head and other trainable decoder weights: LR 1e-4, weight decay 0.01
  • biases, norms, shared embeddings: LR 1e-4, weight decay 0.0

5. Normalize text before CER and WER

This matters more for Hindi than people expect.

Before computing metrics, normalize both prediction and reference with:

  • Unicode NFC normalization
  • .strip()
  • whitespace collapse

That removes avoidable Unicode and spacing noise from the metric.
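A concrete example of why NFC matters for Devanagari: precomposed nukta letters such as क़ (U+0958) canonically decompose to base letter + nukta, so two visually identical strings can disagree codepoint-by-codepoint and inflate CER for no real error. A sketch of the normalization described above:

```python
import re
import unicodedata


def normalize_text(text: str) -> str:
    # NFC normalization, strip, and whitespace collapse.
    text = unicodedata.normalize("NFC", text)
    return re.sub(r"\s+", " ", text.strip())


a = "\u0958"        # precomposed DEVANAGARI LETTER QA
b = "\u0915\u093C"  # KA + NUKTA, visually identical
print(a == b)                                   # False
print(normalize_text(a) == normalize_text(b))   # True
```

Apply the same function to both prediction and reference before computing CER/WER, so the metric measures recognition errors rather than encoding differences.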

6. Check whether max_length=64 is truncating your labels

In your dataset class, you set max_length=64.

That may be too small for some Hindi lines. If targets are being truncated, the model can never predict the full reference correctly, and your metrics are capped by preprocessing rather than training.

Before the next run, print:

  • max tokenized label length
  • 95th percentile tokenized length
  • number of samples hitting max_length

If many lines are hitting 64, increase it.
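A sketch of that check. Here `tokenize` is a stand-in for something like `lambda t: tokenizer(t).input_ids`, and the helper name is illustrative:

```python
def label_length_stats(texts, tokenize, max_length: int = 64) -> dict:
    # tokenize: callable mapping a string to a list of token ids.
    lengths = sorted(len(tokenize(t)) for t in texts)
    p95 = lengths[int(round(0.95 * (len(lengths) - 1)))]
    at_limit = sum(1 for n in lengths if n >= max_length)
    return {"max": lengths[-1], "p95": p95, "at_or_over_max": at_limit}
```

If `at_or_over_max` is more than a handful of lines, raise `max_length` before the next run; otherwise the metric ceiling is set by preprocessing, not by the model.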

7. Split metrics by printed versus handwritten

This is essential in your case.

Because your dataset mixes printed and handwritten Hindi lines, one combined CER can hide a lot. A model could improve strongly on printed lines and fail on handwritten lines, while the aggregate metric still looks “okay.”

So report:

  • overall CER/WER
  • printed-only CER/WER
  • handwritten-only CER/WER

That will tell you much more than a single global score.
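A sketch of per-domain reporting, assuming each validation record carries a "domain" tag; `score_fn` is any pairwise metric such as a CER function:

```python
from collections import defaultdict


def scores_by_domain(records, score_fn) -> dict:
    # records: iterable of dicts with "domain", "ref", "pred" keys.
    buckets = defaultdict(list)
    for r in records:
        s = score_fn(r["ref"], r["pred"])
        buckets["overall"].append(s)
        buckets[r["domain"]].append(s)
    return {d: sum(v) / len(v) for d, v in buckets.items()}
```

Logging all three averages per epoch makes it immediately visible when printed lines improve while handwritten lines stall.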

The freeze schedule I recommend

This is the schedule I would actually use.

Stage A. Your first real-data training stage

Freeze:

  • entire encoder

Train:

  • enc_to_dec_proj
  • decoder cross-attention
  • lm_head
  • shared embeddings
  • decoder norms

This is the best first-stage setup for your current notebook.

Stage B. If Stage A improves but plateaus

Keep encoder frozen.

Also unfreeze:

  • the last 2 decoder blocks

That gives the text side more flexibility without exploding trainable parameters.
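A sketch of selecting those blocks by parameter name. This assumes mT5-style names like decoder.block.7...; mt5-small has 8 decoder blocks, so adjust `total_blocks` for other sizes:

```python
def last_decoder_block_names(param_names, n_blocks: int = 2, total_blocks: int = 8):
    # Returns parameter names belonging to the last n decoder blocks.
    targets = tuple(
        f"decoder.block.{i}." for i in range(total_blocks - n_blocks, total_blocks)
    )
    return [name for name in param_names if any(t in name for t in targets)]
```

In Stage B you would iterate `model.mt5.named_parameters()` and set `p.requires_grad = True` for the names this returns, on top of the Stage A trainable set.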

Stage C. Only after 1000+ samples

Only then unfreeze:

  • the top 2 encoder blocks
  • at a much smaller LR, like 1e-5

Do not full-fine-tune the encoder now.

What I would not do

Do not freeze the decoder and train encoder only. That freezes exactly the side that needs to adapt to Hindi output and to the new cross-attention bridge. (Hugging Face)

Whether you should use 210 or 1000 next

My answer is:

  • keep the 210-sample experiment
  • but use it only as a training-recipe validation run
  • do not use it to choose the final model design

Once the trainer is fixed, if you can get below about CER 0.5 and the outputs look readable on both printed and handwritten lines, then I would move to 1000+ line crops immediately.

If, after fixing the trainer, CER still stays around 1.0, I would not scale yet. I would first try a parameter-efficient adaptation path, not full fine-tuning.

That recommendation is aligned with recent low-resource OCR work. Nayana uses LoRA to adapt OCR models across low-resource Indic languages, including Hindi, and the original LoRA paper explains why freezing most pretrained weights and training a small adapter can be more stable and much cheaper than full fine-tuning. (ACL Anthology)

What result would convince me that your pipeline is ready

For this specific proof-of-concept, I would call it “ready for larger datasets” if all of these are true:

  • validation CER is below 0.50
  • validation WER is below 0.80
  • outputs are readable and mostly non-repetitive
  • printed and handwritten subsets both improve
  • the best checkpoint is found early and can be restored
  • the result is stable across at least 2 different random splits

That last point matters because your current dataset is tiny. One split can flatter or punish you.
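A sketch of generating reproducible alternate splits for that stability check; the seeds and counts are arbitrary examples:

```python
import random


def make_splits(n_samples: int, n_train: int, seeds=(0, 1)):
    # One (train_idx, val_idx) pair per seed, for split-stability checks.
    splits = []
    for seed in seeds:
        idx = list(range(n_samples))
        random.Random(seed).shuffle(idx)
        splits.append((sorted(idx[:n_train]), sorted(idx[n_train:])))
    return splits
```

Run the full training recipe once per split and compare the best validation CERs; if they diverge wildly, the dataset is too small to trust either number.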

My honest bottom line

Your current case looks like this:

  • The architecture is no longer the main problem.
  • The trainer is now the main problem.
  • 210 samples is enough to validate the pipeline, but not enough to judge model quality.
  • Full fine-tuning ~400M parameters on 160 training lines is the wrong experiment.
  • Freezing the decoder and training the encoder only is the wrong direction for this model.
  • A staged frozen-encoder, decoder-side adaptation strategy is the correct next step. (Hugging Face)

So my advice is:

  1. keep the current architecture,
  2. fix the trainer,
  3. rerun the 210-line proof with CER-driven control, early stopping, and per-domain metrics,
  4. then move to 1000+ samples only after you get a stable result.

Here is the minimal patch I would make to your last two cells only.

This patch keeps your current model and dataset flow intact. It changes the trainer so that CER becomes the main control metric, generation runs every epoch, the frozen-encoder / decoder-side adaptation strategy stays in place, and AdamW uses a more standard weight-decay split for norm and bias parameters. That matches how warm-started vision-encoder-decoder hybrids typically need decoder-side fine-tuning, how mT5 expects decoder starts from pad_token_id, and how Hugging Face examples group AdamW parameters. (Hugging Face)

It also fixes the specific failure mode your current trainer has: loss can keep improving while free-generation OCR gets worse. CER is a character-level edit-distance metric where lower is better, and insertion-heavy outputs can drive it above 1.0, so it is the better signal for checkpointing and early stopping in your setup. (GitHub)

Replace Cell 14 with this

import os
import re
import random
import unicodedata
import torch
import torch.nn as nn
from jiwer import wer as compute_wer_score


def normalize_text(text: str) -> str:
    text = unicodedata.normalize("NFC", text)
    text = text.strip()
    text = re.sub(r"\s+", " ", text)
    return text


class HindiOCRTrainer:

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        tokenizer,
        device,
        output_dir,
        num_epochs=20,
        learning_rate=1e-4,
        grad_clip_norm=1.0,
        patience=5,
        min_delta=0.005,
        max_new_tokens=64,
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.tokenizer = tokenizer
        self.device = device
        self.output_dir = output_dir
        self.num_epochs = num_epochs
        self.grad_clip_norm = grad_clip_norm
        self.patience = patience
        self.min_delta = min_delta
        self.max_new_tokens = max_new_tokens

        os.makedirs(output_dir, exist_ok=True)

        # Restore dropout after single-sample overfit test
        dropout_rate = model.mt5.config.dropout_rate
        for m in model.modules():
            if isinstance(m, nn.Dropout):
                m.p = dropout_rate

        # -----------------------------
        # Freeze strategy: Stage A
        # -----------------------------
        # Freeze full encoder
        for p in model.encoder.parameters():
            p.requires_grad = False

        # Train decoder cross-attention + output-side params
        for name, p in model.mt5.named_parameters():
            p.requires_grad = (
                ("EncDecAttention" in name) or
                ("lm_head" in name) or
                ("shared" in name) or
                ("layer_norm" in name)
            )

        # Train encoder->decoder projection if present
        if model.enc_to_dec_proj is not None:
            for p in model.enc_to_dec_proj.parameters():
                p.requires_grad = True

        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f"Trainable params: {trainable:,} / {total:,}")

        # -----------------------------
        # Optimizer with param groups
        # -----------------------------
        bridge_params = []
        decay_params = []
        no_decay_params = []

        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue

            if "enc_to_dec_proj" in name or "EncDecAttention" in name:
                bridge_params.append(p)
            elif any(x in name for x in ["bias", "LayerNorm.weight", "layer_norm.weight", "shared"]):
                no_decay_params.append(p)
            else:
                decay_params.append(p)

        optimizer_groups = []
        if bridge_params:
            optimizer_groups.append({
                "params": bridge_params,
                "lr": 2e-4,
                "weight_decay": 0.01,
            })
        if decay_params:
            optimizer_groups.append({
                "params": decay_params,
                "lr": learning_rate,
                "weight_decay": 0.01,
            })
        if no_decay_params:
            optimizer_groups.append({
                "params": no_decay_params,
                "lr": learning_rate,
                "weight_decay": 0.0,
            })

        self.optimizer = torch.optim.AdamW(optimizer_groups)

        # CER is the control signal, not val_loss
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            factor=0.5,
            patience=2,
            min_lr=1e-6,
        )

        self.best_cer = float("inf")
        self.best_wer = float("inf")
        self.best_checkpoint_path = os.path.join(self.output_dir, "best_model.pt")
        self.training_log = []
        self.bad_epochs = 0

    # ============================================================
    # TRAIN
    # ============================================================
    def _train_one_epoch(self):
        self.model.train()
        total_loss = 0.0

        for batch in self.train_loader:
            pixel_values = batch["pixel_values"].to(self.device)
            labels = batch["labels"].to(self.device)

            self.optimizer.zero_grad()

            outputs = self.model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / len(self.train_loader)

    # ============================================================
    # VALIDATE
    # ============================================================
    def _validate(self, epoch):
        self.model.eval()
        total_loss = 0.0
        all_predictions = []
        all_ground_truths = []

        with torch.no_grad():
            for batch in self.val_loader:
                pixel_values = batch["pixel_values"].to(self.device)
                labels = batch["labels"].to(self.device)

                outputs = self.model(pixel_values=pixel_values, labels=labels)
                total_loss += outputs.loss.item()

                generated_ids = self.model.generate(
                    pixel_values=pixel_values,
                    max_new_tokens=self.max_new_tokens,
                    num_beams=1,
                    do_sample=False,
                )

                clean_labels = labels.clone()
                clean_labels[clean_labels == -100] = self.tokenizer.pad_token_id

                pred_texts = self.tokenizer.batch_decode(
                    generated_ids, skip_special_tokens=True
                )
                gt_texts = self.tokenizer.batch_decode(
                    clean_labels, skip_special_tokens=True
                )

                pred_texts = [normalize_text(x) for x in pred_texts]
                gt_texts = [normalize_text(x) for x in gt_texts]

                all_predictions.extend(pred_texts)
                all_ground_truths.extend(gt_texts)

        avg_val_loss = total_loss / len(self.val_loader)
        cer = self._compute_cer(all_predictions, all_ground_truths)
        wer = compute_wer_score(all_ground_truths, all_predictions)
        exact_match = sum(p == g for p, g in zip(all_predictions, all_ground_truths)) / max(1, len(all_predictions))

        return avg_val_loss, cer, wer, exact_match, all_predictions, all_ground_truths

    # ============================================================
    # CER
    # ============================================================
    def _compute_cer(self, predictions, ground_truths):
        total_edits = 0
        total_chars = 0

        for pred, gt in zip(predictions, ground_truths):
            total_edits += self._edit_distance(pred, gt)
            total_chars += len(gt)

        return 0.0 if total_chars == 0 else total_edits / total_chars

    def _edit_distance(self, s1, s2):
        m, n = len(s1), len(s2)
        dp = list(range(n + 1))
        for i in range(1, m + 1):
            prev = dp[0]
            dp[0] = i
            for j in range(1, n + 1):
                temp = dp[j]
                if s1[i - 1] == s2[j - 1]:
                    dp[j] = prev
                else:
                    dp[j] = 1 + min(prev, dp[j], dp[j - 1])
                prev = temp
        return dp[n]

    # ============================================================
    # CHECKPOINT
    # ============================================================
    def _save_checkpoint(self, epoch, is_best=False):
        checkpoint = {
            "epoch": epoch,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict(),
            "best_cer": self.best_cer,
            "best_wer": self.best_wer,
            "training_log": self.training_log,
        }

        if is_best:
            torch.save(checkpoint, self.best_checkpoint_path)
            print(f"  ✅ Best model saved → {self.best_checkpoint_path}")

        if epoch % 5 == 0:
            periodic_path = os.path.join(self.output_dir, f"checkpoint_epoch_{epoch}.pt")
            torch.save(checkpoint, periodic_path)
            print(f"  💾 Periodic checkpoint saved → {periodic_path}")

    # ============================================================
    # MAIN LOOP
    # ============================================================
    def train(self):
        print("=" * 60)
        print(f"Starting training for {self.num_epochs} epochs")
        print("Generation metrics computed every epoch")
        print("Primary control metric: CER")
        print("Target CER for PoC: < 0.50")
        print("=" * 60)

        for epoch in range(1, self.num_epochs + 1):
            avg_train_loss = self._train_one_epoch()
            avg_val_loss, cer, wer, exact_match, predictions, ground_truths = self._validate(epoch)

            # Scheduler follows CER, not val_loss
            self.scheduler.step(cer)
            current_lr = self.optimizer.param_groups[0]["lr"]

            log_entry = {
                "epoch": epoch,
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
                "cer": cer,
                "wer": wer,
                "exact_match": exact_match,
                "lr": current_lr,
            }
            self.training_log.append(log_entry)

            print(
                f"Epoch {epoch:>2}/{self.num_epochs} | "
                f"Train Loss: {avg_train_loss:.4f} | "
                f"Val Loss: {avg_val_loss:.4f} | "
                f"CER: {cer:.4f} | "
                f"WER: {wer:.4f} | "
                f"EM: {exact_match:.3f} | "
                f"LR: {current_lr:.2e}"
            )

            # Qualitative preview
            print("  Sample predictions:")
            indices = random.sample(range(len(predictions)), min(3, len(predictions)))
            for i in indices:
                print(f"    GT:   '{ground_truths[i]}'")
                print(f"    Pred: '{predictions[i]}'")

            improved = cer < (self.best_cer - self.min_delta)

            if improved:
                self.best_cer = cer
                self.best_wer = wer
                self.bad_epochs = 0
                self._save_checkpoint(epoch, is_best=True)
                print(f"  🎯 New best CER: {cer:.4f}")
                if cer < 0.50:
                    print("  ✅ TARGET REACHED: CER < 0.50")
            else:
                self.bad_epochs += 1
                print(f"  No CER improvement. Patience: {self.bad_epochs}/{self.patience}")

            self._save_checkpoint(epoch, is_best=False)

            if self.bad_epochs >= self.patience:
                print("  ⏹️ Early stopping triggered.")
                break

        print("\n" + "=" * 60)
        print("Training complete.")
        print(f"Best CER: {self.best_cer:.4f}")
        print(f"Best WER: {self.best_wer:.4f}")

        if os.path.exists(self.best_checkpoint_path):
            checkpoint = torch.load(self.best_checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state"])
            print("✅ Restored best checkpoint into model.")

        return self.training_log

    # ============================================================
    # OPTIONAL RESUME
    # ============================================================
    def load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state"])
        self.best_cer = checkpoint["best_cer"]
        self.best_wer = checkpoint["best_wer"]
        self.training_log = checkpoint.get("training_log", [])
        start_epoch = checkpoint["epoch"] + 1
        print(f"✅ Resumed from epoch {checkpoint['epoch']} | Best CER so far: {self.best_cer:.4f}")
        return start_epoch

Replace Cell 15 with this

trainer = HindiOCRTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    tokenizer=tokenizer,
    device=device,
    output_dir="/content/drive/MyDrive/trocr-checkpoints",
    num_epochs=20,
    learning_rate=1e-4,
    grad_clip_norm=1.0,
    patience=5,
    min_delta=0.005,
    max_new_tokens=64,
)

training_log = trainer.train()

What changed, and why

The smallest important changes are these:

  • CER/WER now run every epoch, so the best OCR checkpoint cannot be skipped between 5-epoch intervals.
  • Scheduler now follows CER, not validation loss.
  • Early stopping stops the run once CER stops improving.
  • Parameter groups give the bridge and cross-attention a slightly higher LR, while excluding norm and bias-style parameters from weight decay in the usual Hugging Face pattern. (GitHub)
  • Text normalization is applied before metrics, which matters for Unicode-heavy scripts like Devanagari.
  • The class still keeps your current frozen encoder + decoder-side adaptation strategy, which is the right first stage for a warm-started vision-text hybrid where decoder-side cross-attention is the newly learned bridge. (Hugging Face)
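Concretely: Devanagari allows visually identical strings with different code-point sequences, and NFC collapses them before the edit distance runs. A tiny self-contained illustration (क़ / QA is one such letter):

```python
import unicodedata

# क़ (QA) can arrive precomposed (U+0958) or as क + nukta (U+0915 U+093C).
qa_one = "\u0958"
qa_two = "\u0915\u093c"

assert qa_one != qa_two  # raw strings differ -> would inflate CER
assert unicodedata.normalize("NFC", qa_one) == unicodedata.normalize("NFC", qa_two)
```

Without this step, a prediction that renders identically to the ground truth can still be charged edit-distance errors.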

One small check before you rerun

Your dataset cell still uses max_length=64 for labels. Before the next run, quickly inspect how many targets are hitting that cap. If a noticeable fraction of your Hindi lines are truncated at 64 tokens, raise it first. Otherwise the trainer can improve while the metric ceiling stays artificially low.

Use this once, anywhere after dataset creation:

lengths = []
for i in range(len(train_dataset)):
    ids = train_dataset[i]["labels"]
    valid = (ids != -100).sum().item()
    lengths.append(valid)

print("max label tokens:", max(lengths))
print("p95 label tokens:", sorted(lengths)[int(0.95 * len(lengths))])
print("num hitting 64:", sum(x >= 64 for x in lengths))

What I would expect after this patch

If the pipeline is healthy, you should see this pattern:

  • training loss goes down,
  • validation CER improves in the first several epochs,
  • the best checkpoint appears before the final epoch,
  • and early stopping restores that checkpoint instead of letting later degradation define the run.

If after this patch CER still stays around 1.0 or worse, I would not jump to “different tokenizer” or “different architecture.” I would first test one of these two moves:

  • keep the same freeze strategy and move to 1000+ line crops, or
  • keep the same data and switch to LoRA-style adaptation on the decoder-side trainable blocks, which is a common low-resource adaptation path in recent Indic OCR work. (ACL Anthology)

hi @John6666 ,

I tried the changes you suggested, but the results are still not showing progress. There are lots of repetitions, no matter what.

Epoch 3/20 | Train Loss: 5.8161 | Val Loss: 5.2420 | CER: 1.6543 | WER: 2.4012 | EM: 0.000 | LR: 2.00e-04
Sample predictions:
GT: ‘आवश्यकताओं को दृष्टिगत रखते हुए यथावश्यक संशोधन किए गए हैं। प्रत्येक संस्करण में निदेशालय का’
Pred: ‘<extra_id_0> के लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए’
GT: ‘प्रागैतिहासिक काल के शिलालेखों, भित्तिचित्रों को देखकर यह अनुमान करना सहज है’
Pred: ‘<extra_id_0> के लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए और और और और और और और और और और और और और’
GT: ‘ब्राह्मी लिपि और उससे प्रसूत देवनागरी’
Pred: ‘<extra_id_0>, लिए के लिए के लिए के लिए के लिए के लिए के लिए के लिए के लिए के लिए के लिए के लिए के लिए के लिए के लिए के लिए लिए के लिए लिए के लिए लिए के लिए’
No CER improvement. Patience: 2/5
Epoch 4/20 | Train Loss: 5.4570 | Val Loss: 4.8928 | CER: 1.1956 | WER: 1.8109 | EM: 0.000 | LR: 1.00e-04
Sample predictions:
GT: ‘निम्नलिखित प्रत्येक पर लगभग 150 शब्दों में टिप्पणियाँ लिखिए :’
Pred: ‘<extra_id_0>, और और और और और और और और और और और और और और और और और और और और और और और और और और और और और और और’
GT: ‘व्यवस्था किसी अन्य लिपि में दुर्लभ है। देवनागरी लिपि में पर्याप्त ध्वनि चिह्न होने के कारण’
Pred: ‘<extra_id_0>, लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए लिए और’
GT: ‘तकनीकी आरेखण पर अक्षरांकन’
Pred: ‘<extra_id_0>,’
No CER improvement. Patience: 3/5
Epoch 5/20 | Train Loss: 5.0633 | Val Loss: 4.7285 | CER: 1.3933 | WER: 2.1244 | EM: 0.000 | LR: 1.00e-04
Sample predictions:
GT: ‘और परिष्कृत रूप में प्रस्तुत किया जाए। इस संस्करण को प्रकाशित करते हुए भी हमारे मन में यही भाव है।’
Pred: ‘<extra_id_0> के लिए और और और और और और और और और और और और और और और और और और और और और और और और और और और और और और’
GT: ‘प्रो. सदानंद प्रसाद गुप्त, गोरखपुर, उ.प्र.’
Pred: ‘<extra_id_0>क,’
GT: ‘बढ़ते हुए इस संस्करण में दक्षिण के साथ-साथ कई भारतीय भाषाओं और हिंदी की अधिकांश बोलियों’
Pred: ‘<extra_id_0> लिए और और और और और और और और और और और और और और और और और और और और और और और और और और और और और और’
No CER improvement. Patience: 4/5
💾 Periodic checkpoint saved → /content/drive/MyDrive/trocr-checkpoints/checkpoint_epoch_5.pt
Epoch 6/20 | Train Loss: 4.9550 | Val Loss: 4.6576 | CER: 1.3935 | WER: 1.8901 | EM: 0.000 | LR: 1.00e-04
Sample predictions:
GT: ‘कंप्यूटर और देवनागरी का तालमेल’
Pred: ‘<extra_id_0> के और, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली, दिल्ली’
GT: ‘उत्कंठा रही है। किसी सतह पर अपने विचार लिखने या दर्ज करने के प्राचीनतम प्रमाण’
Pred: ‘<extra_id_0> चंद्रन,’
GT: ‘केवल भारत की संपर्क भाषा है, बल्कि वैश्विक परिदृश्य में अपनी स्वीकार्यता और तकनीकी सुदृढ़ता के’
Pred: ‘<extra_id_0>, और और और और और और और और और और और और और और और और और और और और और और और और और और और और और और और’
No CER improvement. Patience: 5/5
⏹️ Early stopping triggered.

============================================================
Training complete.
Best CER: 1.0663
Best WER: 2.6014

Also, in the overfit test I noticed one thing: sometimes the model sticks to the target text very well, and the majority of steps give correct output.

But on other target texts it performs very badly. My concern is particularly the occurrence of <extra_id_0>.

I have tried two freezing strategies in the overfit test. The unpredictable results appear with both.
one being this:

for p in model.encoder.parameters():
    p.requires_grad = False

for name, p in model.mt5.named_parameters():
    p.requires_grad = (
        ("EncDecAttention" in name) or
        ("lm_head" in name) or
        ("shared" in name)
    )

if model.enc_to_dec_proj is not None:
    for p in model.enc_to_dec_proj.parameters():
        p.requires_grad = True

and the other being this one:

for p in model.encoder.parameters():
    p.requires_grad = False

for name, p in model.mt5.named_parameters():
    p.requires_grad = (
        ("EncDecAttention" in name) or
        ("DenseReluDense" in name) or  # ← added in this variant
        ("lm_head" in name)
        # ("shared" in name)
    )

if model.enc_to_dec_proj is not None:
    for p in model.enc_to_dec_proj.parameters():
        p.requires_grad = True

Can you help me understand why this is happening? Sometimes the model fits well, sometimes it doesn't. Claude says it's the random initialization of weights that is causing the issue.


Hmm… Hypothesis that the problem is occurring on the decoder side:


My overall conclusion

I do not think you should change the tokenizer or abandon the current TrOCR-encoder + mT5-decoder setup yet.

I think the current evidence says something more specific:

  • the architecture can work, because some targets overfit correctly;
  • the decoder is still under-adapted on harder lines;
  • and tokens like <extra_id_0> are not random accidents; they are a very specific sign that the T5-family decoder is falling back to its pretraining behavior when OCR grounding is weak. T5-family tokenizers include extra_ids special tokens, and the original T5 pretraining objective uses sentinel tokens as part of span corruption. (huggingface.co)

So my main recommendation is:

Keep the current setup, but change how you adapt and evaluate the decoder.


Why <extra_id_0> appears at all

This is the first thing to understand.

<extra_id_0> is a built-in special token in the T5 and mT5 tokenizer family. It is not some random OCR artifact. T5 was pretrained with a denoising objective that literally teaches the model to emit sentinel tokens like <extra_id_0> when reconstructing masked spans. That means these tokens have very strong pretrained priors. (arxiv.org)

So when your OCR model is uncertain, what can happen is:

  1. the image signal is not strong enough to dominate,
  2. the decoder falls back to familiar pretrained behavior,
  3. sentinel tokens and repetitive continuations leak into generation.

That is why your outputs look like:

  • <extra_id_0>
  • then repeated “लिए”
  • then repeated “और”
  • and other locally high-probability continuations

This is a decoder grounding problem, not a token-coverage problem.


Why some lines fit and others fail

This is the second key idea.

If the only issue were random initialization, you would mostly see run-to-run differences:

  • one run works,
  • another run does not.

But what you are seeing is also line-to-line variation:

  • some target texts overfit nicely,
  • others collapse badly.

That means there are multiple causes at once.

Cause 1: random bridge initialization

Hugging Face’s encoder-decoder docs explicitly note that in warm-started hybrids, the decoder-side cross-attention can be randomly initialized and must be fine-tuned downstream. So yes, some instability is expected. (huggingface.co)

Cause 2: target difficulty is uneven

Some lines are easier:

  • shorter,
  • cleaner,
  • more common vocabulary,
  • fewer punctuation marks,
  • easier crops,
  • more printed than handwritten.

Some are harder:

  • longer,
  • more punctuation,
  • noisier handwriting,
  • denser ligatures,
  • rarer word combinations.

The hard lines require stronger and more stable image grounding. So they expose decoder weakness faster.

Cause 3: your current trainable slice is still too narrow

This is the main practical issue.

Your two freezing strategies both let the model learn some bridge behavior, but they do not give the decoder enough freedom to fully reshape sequence generation for hard OCR lines.

That is why the model can sometimes stick closely to the target and sometimes fail badly.

So my interpretation is:

random initialization contributes, but the bigger story is under-adaptation of the decoder in a tiny-data regime.


Why your two freezing strategies behave this way

Strategy 1

Train only:

  • EncDecAttention
  • lm_head
  • shared
  • projection

This helps the model learn:

  • how to inject image features into the decoder,
  • and how to map decoder states into output tokens.

That is often enough for easy examples.

But it does not fully change the decoder’s internal sequence dynamics.

So if the image signal is weak, the decoder still falls back to pretrained behavior.

Strategy 2

Add:

  • DenseReluDense

This is broader and better than Strategy 1.

But it still leaves other important parts constrained, especially self-attention-driven sequence behavior.

So it can still fail on harder examples.

That is why both strategies can show “sometimes good, sometimes terrible” behavior.

They are not wrong. They are just not broad enough yet for the hard lines.


My recommended solutions, in order

Solution 1. Keep the current architecture

Do not switch tokenizer.
Do not switch away from mT5 yet.
Do not switch away from the TrOCR encoder yet.

Reason:

  • one-sample overfit success proves the wiring can work;
  • <extra_id_0> means decoder fallback, not missing Hindi token support. (huggingface.co)

This is the highest-confidence recommendation.


Solution 2. Use a broader decoder-side adaptation strategy

This is the most important practical change.

Recommended next freeze schedule

Freeze:

  • entire encoder

Train:

  • enc_to_dec_proj
  • all EncDecAttention layers
  • lm_head
  • shared
  • all parameters in the last 2 decoder blocks

This is better than both of your current strategies because it gives the decoder more freedom to change:

  • sequence behavior,
  • grounding behavior,
  • and output token dynamics.

I would use this as the next main training strategy.
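A minimal sketch of that schedule, assuming the wrapper from your earlier cells (model.encoder, model.mt5 with T5-family parameter names, model.enc_to_dec_proj). The name patterns are assumptions; verify them against model.mt5.named_parameters() before trusting the selection.

```python
def stage_b_trainable(name: str, num_decoder_blocks: int) -> bool:
    """True for decoder-side parameters that should train in this stage:
    all cross-attention, lm_head, shared embeddings, and every parameter
    in the last two decoder blocks (T5-family naming assumed)."""
    last_two = {
        f"decoder.block.{i}." for i in (num_decoder_blocks - 2, num_decoder_blocks - 1)
    }
    return (
        "EncDecAttention" in name
        or "lm_head" in name
        or "shared" in name
        or any(prefix in name for prefix in last_two)
    )


def apply_stage_b_freeze(model):
    """Freeze the encoder, then selectively unfreeze decoder-side parameters."""
    for p in model.encoder.parameters():
        p.requires_grad = False
    num_blocks = len(model.mt5.decoder.block)
    for name, p in model.mt5.named_parameters():
        p.requires_grad = stage_b_trainable(name, num_blocks)
    if model.enc_to_dec_proj is not None:
        for p in model.enc_to_dec_proj.parameters():
            p.requires_grad = True
```

The trailing dot in the block prefixes matters: it keeps "decoder.block.1." from also matching "decoder.block.10." on deeper models.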

Why this makes sense:

  • the encoder already gives usable image features;
  • the fragile part is still the decoder-side bridge and generation;
  • Hugging Face’s docs already point to cross-attention as the new component that often needs fine-tuning in warm-started hybrids. (huggingface.co)

What I would not do

Do not unfreeze the encoder yet.

That is too early for your data size and not where the failure signal is pointing.


Solution 3. Suppress sentinel tokens during validation and inference

This is a very useful guardrail.

Hugging Face generation utilities support bad_words_ids, which lets you block specific tokens or token sequences during generation. Since <extra_id_n> tokens should never be valid OCR output for your task, you can suppress them during validation and inference. (huggingface.co)

Example idea:

# Look up sentinel ids directly; re-tokenizing special tokens as plain text
# can split them on some SentencePiece tokenizers.
extra_tokens = [f"<extra_id_{i}>" for i in range(100)]
bad_words_ids = [[tokenizer.convert_tokens_to_ids(t)] for t in extra_tokens]

generated_ids = model.generate(
    pixel_values=pixel_values,
    max_new_tokens=max_new_tokens,
    num_beams=1,
    do_sample=False,
    bad_words_ids=bad_words_ids,
)

Important caution:

  • this is not the real fix,
  • it is a guardrail.

It prevents the most obviously invalid decoder fallback behavior from polluting your evaluation, while you keep working on the actual training problem.


Solution 4. Split your tests into easy lines and hard lines

Right now your model feels “unpredictable” because you are mentally averaging together different difficulty levels.

Do this instead:

Easy probe set

Use lines that are:

  • shorter,
  • cleaner,
  • more printed,
  • less punctuation-heavy,
  • more common vocabulary.

Hard probe set

Use lines that are:

  • longer,
  • more punctuation-heavy,
  • noisier handwriting,
  • more complex Devanagari forms,
  • more unusual vocabulary.

Then run the same overfit test on both.

This will tell you much more than one mixed impression.

If easy lines fit but hard lines do not, then the explanation is not “just random init.”
It is:

  • random init,
  • plus hard-target difficulty,
  • plus decoder under-adaptation.
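One crude way to build the two probe sets, as a sketch; the "text" key, the length threshold, and the punctuation set are all assumptions to tune for your data:

```python
def split_probe_sets(samples, easy_max_chars=40, easy_max_punct=2):
    """Heuristic easy/hard split by target length and punctuation density."""
    punct_chars = set(",.;:!?()\"'-।")  # includes the Devanagari danda
    easy, hard = [], []
    for sample in samples:
        text = sample["text"]
        punct = sum(ch in punct_chars for ch in text)
        if len(text) <= easy_max_chars and punct <= easy_max_punct:
            easy.append(sample)
        else:
            hard.append(sample)
    return easy, hard
```

Run the same overfit test on each half and compare CER; the gap between the two halves is the signal you want.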

Solution 5. Add three diagnostics

These three diagnostics will make your debugging much clearer.

A. Sentinel-token rate

Track how often predictions contain <extra_id_0> or any <extra_id_n>.

This tells you whether the decoder is still falling back to T5 pretraining behavior.

B. Length ratio

Track:

  • len(prediction) / len(reference)

If this ratio explodes, repetition and EOS failure are dominating.

C. Target token length

Track tokenized target length for each line.

Hard examples often cluster here.

These three numbers will be more informative than loss alone.
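A minimal helper for diagnostics A and B; diagnostic C needs your tokenizer (e.g. len(tokenizer(ref).input_ids) per reference), so it is left out here:

```python
def decoder_health(predictions, references):
    """Sentinel-token rate and mean character-length ratio for one
    validation pass over (prediction, reference) string pairs."""
    n = max(1, len(predictions))
    sentinel_rate = sum("<extra_id_" in p for p in predictions) / n
    ratios = [len(p) / max(1, len(r)) for p, r in zip(predictions, references)]
    mean_length_ratio = sum(ratios) / max(1, len(ratios))
    return {"sentinel_rate": sentinel_rate, "mean_length_ratio": mean_length_ratio}
```

Logging this dict per epoch alongside CER makes "still falling back to sentinels" and "still looping past EOS" visible at a glance.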


Solution 6. Tighten generation length

A flat max_new_tokens=64 is probably too blunt.

My recommendation is:

  • compute the 95th percentile target token length in your dataset,
  • then set max_new_tokens = p95 + 4.

Why:

  • long ceilings give unstable models more room to loop,
  • shorter, data-driven ceilings reduce runaway repetition.

This is a practical recommendation based on the failure pattern you are seeing.
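As a sketch, reusing the lengths list from the label-length check earlier (an assumption about where you computed it):

```python
def data_driven_max_new_tokens(label_lengths, margin=4, floor=8):
    """Generation ceiling = 95th-percentile target token length + margin,
    with a small floor so tiny datasets don't collapse the budget."""
    ordered = sorted(label_lengths)
    p95 = ordered[int(0.95 * (len(ordered) - 1))]
    return max(floor, p95 + margin)
```

Pass the result as max_new_tokens when constructing the trainer instead of the flat 64.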


Solution 7. Move to LoRA if the above still fails

If the broader decoder-side adaptation still gives unstable behavior, my next recommendation is:

  • LoRA on cross-attention
  • plus LoRA on the last 2 decoder blocks

This is not just a generic modern preference. Recent low-resource Indic OCR work uses LoRA-style parameter-efficient adaptation, and the original LoRA paper explains why adapting only a small low-rank slice is often more stable and much cheaper than broad full fine-tuning. (aclanthology.org)

So if the broader partial fine-tuning still collapses, I would move to LoRA before touching the encoder.
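To see why this tends to be stable, here is a minimal, library-free sketch of the LoRA mechanism itself. In practice you would use the peft library (LoraConfig with target_modules) rather than hand-wiring this; the sketch only illustrates the frozen-weight-plus-low-rank-update idea.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank update: W + (alpha/r) * B @ A."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # pretrained weight stays frozen
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # update starts at exactly zero
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))
```

Because the update starts at zero, a wrapped cross-attention projection initially behaves exactly like the pretrained layer, which is the stability property the LoRA paper highlights.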


Solution 8. Keep 210 samples for debugging, but move to 1000+ once decoder behavior stabilizes

Your 210-sample dataset is still useful. It is a good debugging set.

But it is not enough to expect stable OCR behavior from a large hybrid model across both printed and handwritten Hindi.

That is why recent low-resource Indic OCR work uses PEFT and synthetic data, and why public Indic handwriting resources are much larger. This is not a sign that your approach is bad; it is a sign that your current data regime is very small. (cvit.iiit.ac.in)

So my recommendation is:

  • keep the 210 samples for trainer and decoder debugging,
  • move to 1000+ line crops once the decoder stops collapsing into sentinels and repetitions.

My final recommendation stack

If I compress everything into the clearest sequence, it is this:

First

Keep the architecture.

Second

Train:

  • encoder frozen,
  • all cross-attention trainable,
  • last 2 decoder blocks trainable,
  • lm_head,
  • shared,
  • projection.

Third

Suppress <extra_id_n> tokens during validation and inference with bad_words_ids.

Fourth

Use easy vs hard probe sets and log:

  • sentinel-token rate,
  • length ratio,
  • target token length.

Fifth

If instability persists, move to LoRA on decoder-side blocks.

Sixth

Scale to 1000+ line crops once decoder behavior becomes sane.


My one-sentence summary

Your problem is not “wrong tokenizer.”
It is decoder fallback to T5 priors plus under-adapted generation on a tiny mixed dataset.

So my recommended solution is:

Keep the current setup, broaden decoder-side adaptation to the last 2 decoder blocks plus all cross-attention, suppress sentinel tokens during validation, and use LoRA next if the decoder is still unstable.

That is the path I think best fits what you are seeing now.