I can’t promise the patch at the end of this answer will run as-is, but it is where I would start. For your case, the right conclusion is:
210 samples is a valid proof-of-concept dataset, but only if you treat it as a pipeline test, not a quality benchmark.
Your current notebook is close enough to keep. The main problem is no longer encoder-decoder wiring. The main problem is the training control logic and the fact that 160 training lines is far too little data for full fine-tuning of hundreds of millions of parameters. In warm-started encoder-decoder hybrids, the cross-attention bridge can be randomly initialized and needs downstream fine-tuning, and in mT5 the decoder starts from pad_token_id, which matches the direction you already fixed in your notebook. (Hugging Face)
My direct answer to your three questions
1. Is 210 samples okay?
Yes, for a proof that the pipeline is capable of learning.
No, for judging final model quality or choosing the best architecture.
A 210-line dataset can answer a narrow but useful question:
“Can this TrOCR-encoder plus mT5-decoder pipeline learn stable Hindi line recognition on real data without collapsing?”
That is a good proof-of-concept question. It is much narrower than “Is this good OCR?” and that is the correct way to use 210 samples. Public TrOCR documentation says the raw model is intended for single text-line images, so your use of line crops is aligned with the model family’s intended use. At the same time, public Indic OCR resources are much larger. IIIT-INDIC-HW-WORDS reports 872K handwritten instances across 8 Indic scripts, and recent low-resource Indic OCR work explicitly leans on synthetic data or parameter-efficient adaptation because small real datasets are not enough by themselves. (Hugging Face)
So the right framing is:
- 210 lines = enough to validate that the training recipe is sane.
- 1000+ lines = much better for deciding whether the model scales.
- many thousands = where quality conclusions start to matter.
2. What CER and WER should you aim for?
For your current dataset size, use CER as the primary metric and WER as a secondary metric.
That is because CER is character-level edit distance and is smoother on tiny validation sets, while WER is harsher and noisier on short lines. Hugging Face’s CER implementation defines CER as character-level edit distance normalized by reference length, and WER is the word-level analogue based on substitutions, deletions, and insertions. Both are “lower is better,” and both can exceed 1.0 when insertions are large. So “CER below 1.0” is not a meaningful success threshold. A CER around 0.98 is still very poor. (GitHub)
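To make the “CER above 1.0” point concrete, here is a minimal self-contained sketch using the same rolling-table Levenshtein distance style the patch below uses; it is an illustration of the metric, not the notebook’s actual code:

```python
# Minimal sketch: CER = edit distance / reference length, so
# insertion-heavy predictions can push CER above 1.0.
def edit_distance(s1: str, s2: str) -> int:
    # Standard Levenshtein distance with a rolling 1-D table.
    m, n = len(s1), len(s2)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            temp = dp[j]
            dp[j] = prev if s1[i - 1] == s2[j - 1] else 1 + min(prev, dp[j], dp[j - 1])
            prev = temp
    return dp[n]

def cer(prediction: str, reference: str) -> float:
    return edit_distance(prediction, reference) / len(reference)

print(cer("abc", "abc"))                 # 0.0: perfect match
print(round(cer("axc", "abc"), 3))       # 0.333: one substitution
print(round(cer("abcxxxx", "abc"), 3))   # 1.333: four insertions push CER past 1.0
```

This is why a reported CER of 0.98 is nowhere near a success threshold: it means the output is almost one full edit per reference character.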
These are the thresholds I would use for your proof-of-concept, and these are engineering thresholds, not official published cutoffs:
Minimum green light
- validation CER < 0.50
- validation WER < 0.80
- predictions are readable
- no repeated-character collapse
- both printed and handwritten lines improve
Better green light
- validation CER around 0.25 to 0.40
- validation WER around 0.45 to 0.70
Strong green light
- validation CER < 0.20 to 0.25
- validation WER < 0.40 to 0.50
For 210 mixed-domain samples, I would call the pipeline “ready to scale” once you can reliably beat CER 0.5 and show visibly readable predictions on both printed and handwritten validation lines. The important word is reliably. One lucky split is not enough.
3. Which layers should you freeze, and how should the training class be designed?
For your exact architecture, freezing the decoder and training only the encoder is the wrong direction.
The decoder side contains:
- the Hindi text generation behavior,
- the new cross-attention bridge that connects image features to text generation,
- and the autoregressive dynamics that are currently causing metric instability.
Hugging Face’s encoder-decoder docs explicitly say cross-attention may be randomly initialized in these hybrids and must be fine-tuned downstream. That means the most important adaptation is usually decoder-side cross-attention and output-side behavior, not encoder-only retraining. (Hugging Face)
So for your notebook, the correct first training stage is:
- freeze the entire visual encoder
- train enc_to_dec_proj if it exists
- train decoder cross-attention
- train lm_head
- train shared embeddings
- optionally train decoder layer norms
That is the right first-stage design for your current notebook.
My thoughts after checking your training cells
Your notebook is now in a much better place than before. These are the good parts:
- you moved to trocr-small-stage1 plus mt5-small
- the custom wrapper is now sane enough to test
- the overfit test already proved the bridge can learn
- the real trainer freezes the encoder and trains decoder-side bridge/output parameters
That is the right direction.
The weak part is the trainer design, not the architecture.
The biggest trainer problem
Right now, in your notebook:
- validation loss is computed every epoch,
- CER/WER are only computed every 5 epochs,
- the scheduler follows validation loss,
- but the best-checkpoint logic follows CER.
That is a mismatch.
For seq2seq OCR, teacher-forced loss and free-generation CER can move in different directions. The model can keep lowering validation loss while actual OCR output gets worse. That is especially common in tiny-data autoregressive setups. So a scheduler driven by val_loss and checkpointing driven by CER can easily train past the best OCR model.
That is the main reason I do not trust the 50-epoch result as a true judgment of the architecture.
What I think your current CER curve actually means
You said:
- CER started around 1.2 to 1.3
- improved to about 0.98
- then worsened to about 1.5
That pattern usually means:
- the model is learning something at first,
- then overfitting or decoding drift sets in,
- and later epochs add insertions or repetitive garbage.
Since CER is edit-distance based, late insertions can easily push it above 1.0. So this is not “almost good but not quite.” It is “still poor overall, with a brief early improvement that was not preserved.” (GitHub)
The good news is that this pattern usually points to training-control problems, not “your architecture cannot work.”
The exact changes I would make now
1. Compute CER and WER every epoch
On a validation set of about 45 to 50 lines, the extra compute is small. The benefit is large.
In your trainer, set:
generate_every_n_epochs = 1
That alone makes your best-epoch detection much more trustworthy.
2. Use CER as the one metric that controls training
Use validation CER for all three:
- scheduler stepping
- best-checkpoint saving
- early stopping
Do not split those across val_loss and CER.
For this dataset size and task, CER is the best control metric. WER is still useful, but mainly as a reporting metric.
3. Add early stopping
Do not run 50 fixed epochs on 160 training lines.
Use:
- num_epochs = 20 or 25
- patience = 5
- min_delta = 0.005 on CER
The best checkpoint will probably appear earlier than epoch 50. Right now your notebook is not designed to stop there.
4. Use parameter groups, not one flat AdamW group
In your notebook, all trainable parameters currently use one LR and one weight decay. That is too blunt.
Hugging Face’s training docs note that biases and LayerNorm parameters are usually excluded from weight decay. I would go one step further and also give the bridge and cross-attention a slightly higher learning rate than the rest of the decoder-side trainable weights. (Hugging Face)
A good split is:
- bridge and cross-attention: LR 2e-4, weight decay 0.01
- lm_head and other trainable decoder weights: LR 1e-4, weight decay 0.01
- biases, norms, shared embeddings: LR 1e-4, weight decay 0.0
5. Normalize text before CER and WER
This matters more for Hindi than people expect.
Before computing metrics, normalize both prediction and reference with:
- Unicode NFC normalization
- .strip()
- whitespace collapse
That removes avoidable Unicode and spacing noise from the metric.
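A concrete Devanagari example of why this matters: the letter क़ can arrive precomposed (U+0958) or as क plus a combining nukta (U+0915 U+093C). The two forms render identically but compare unequal byte-for-byte, so an un-normalized metric counts a phantom error. This sketch uses the same normalize_text as the patch below:

```python
import re
import unicodedata

def normalize_text(text: str) -> str:
    # Same normalization as the trainer patch: NFC, strip, collapse whitespace.
    text = unicodedata.normalize("NFC", text)
    return re.sub(r"\s+", " ", text.strip())

# क़ stored precomposed (U+0958) versus क + nukta (U+0915 U+093C):
# identical on screen, different in memory.
precomposed = "\u0958"
decomposed = "\u0915\u093C"
print(precomposed == decomposed)                                  # False
print(normalize_text(precomposed) == normalize_text(decomposed))  # True
```

Note that U+0958 is a Unicode composition exclusion, so NFC maps it to the decomposed form; either way, both sides end up identical after normalization, which is all the metric needs.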
6. Check whether max_length=64 is truncating your labels
In your dataset class, you set max_length=64.
That may be too small for some Hindi lines. If targets are being truncated, the model can never predict the full reference correctly, and your metrics are capped by preprocessing rather than training.
Before the next run, print:
- max tokenized label length
- 95th percentile tokenized length
- number of samples hitting max_length
If many lines are hitting 64, increase it.
7. Split metrics by printed versus handwritten
This is essential in your case.
Because your dataset mixes printed and handwritten Hindi lines, one combined CER can hide a lot. A model could improve strongly on printed lines and fail on handwritten lines, while the aggregate metric still looks “okay.”
So report:
- overall CER/WER
- printed-only CER/WER
- handwritten-only CER/WER
That will tell you much more than a single global score.
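Here is a small sketch of the split. It assumes each validation sample carries a printed/handwritten tag (the `domains` list is a hypothetical field; adapt the name to whatever your dataset class actually stores), and it takes the corpus metric as a function so you can pass the trainer’s own CER:

```python
from collections import defaultdict

def metrics_by_domain(predictions, references, domains, metric_fn):
    # Group (prediction, reference) pairs by their domain tag and
    # compute the given corpus-level metric once per group.
    grouped = defaultdict(lambda: ([], []))
    for pred, ref, dom in zip(predictions, references, domains):
        grouped[dom][0].append(pred)
        grouped[dom][1].append(ref)
    return {dom: metric_fn(p, r) for dom, (p, r) in grouped.items()}

# Demo with a stand-in metric (exact-match error rate); in the trainer
# you would pass self._compute_cer instead.
def error_rate(preds, refs):
    return sum(p != r for p, r in zip(preds, refs)) / len(refs)

scores = metrics_by_domain(
    ["नमस्ते", "भारत", "दिल्ली"],
    ["नमस्ते", "भारत", "मुंबई"],
    ["printed", "printed", "handwritten"],
    error_rate,
)
print(scores)  # {'printed': 0.0, 'handwritten': 1.0}
```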
The freeze schedule I recommend
This is the schedule I would actually use.
Stage A. Your first real-data training stage
Freeze:
- the entire visual encoder
Train:
- enc_to_dec_proj
- decoder cross-attention
- lm_head
- shared embeddings
- decoder norms
This is the best first-stage setup for your current notebook.
Stage B. If Stage A improves but plateaus
Keep encoder frozen.
Also unfreeze:
- the last 2 decoder blocks
That gives the text side more flexibility without exploding trainable parameters.
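A sketch of the Stage B switch, assuming the wrapper from the patch below, where the text model is exposed as `model.mt5` and its decoder blocks live in `model.mt5.decoder.block` (the usual T5-family layout; check the attribute names against your wrapper before using):

```python
# Stage B sketch: the visual encoder stays frozen, and the last
# num_blocks mT5 decoder blocks are additionally unfrozen on top of
# the Stage A trainable set.
def apply_stage_b(model, num_blocks: int = 2):
    for p in model.encoder.parameters():
        p.requires_grad = False            # keep the visual encoder frozen
    for block in model.mt5.decoder.block[-num_blocks:]:
        for p in block.parameters():
            p.requires_grad = True         # unfreeze the top decoder blocks
```

Run this after the Stage A freeze logic so the cross-attention, lm_head, shared-embedding, and norm flags set there are preserved.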
Stage C. Only after 1000+ samples
Only then unfreeze:
- the top 2 encoder blocks
- at a much smaller LR, like 1e-5
Do not full-fine-tune the encoder now.
What I would not do
Do not freeze the decoder and train encoder only. That freezes exactly the side that needs to adapt to Hindi output and to the new cross-attention bridge. (Hugging Face)
Whether you should use 210 or 1000 next
My answer is:
- keep the 210-sample experiment
- but use it only as a training-recipe validation run
- do not use it to choose the final model design
Once the trainer is fixed, if you can get below about CER 0.5 and the outputs look readable on both printed and handwritten lines, then I would move to 1000+ line crops immediately.
If, after fixing the trainer, CER still stays around 1.0, I would not scale yet. I would first try a parameter-efficient adaptation path, not full fine-tuning.
That recommendation is aligned with recent low-resource OCR work. Nayana uses LoRA to adapt OCR models across low-resource Indic languages, including Hindi, and the original LoRA paper explains why freezing most pretrained weights and training a small adapter can be more stable and much cheaper than full fine-tuning. (ACL Anthology)
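If you go the parameter-efficient route, the core LoRA idea is small enough to sketch: freeze a pretrained linear layer and learn a low-rank additive update. This is a generic illustration, not Nayana’s actual code; in your model you would wrap, for example, the EncDecAttention query/value projections:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA sketch: y = W x + (alpha / r) * B A x, with W frozen.

    B is zero-initialized, so at step 0 the wrapped layer behaves
    exactly like the frozen pretrained layer.
    """
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False       # pretrained weights stay frozen
        self.lora_a = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling
```

With ranks around 8 to 16 on a few decoder projections, the trainable fraction lands well under one percent of the model, which is why this path tends to be more stable on 160 training lines than full fine-tuning.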
What result would convince me that your pipeline is ready
For this specific proof-of-concept, I would call it “ready for larger datasets” if all of these are true:
- validation CER is below 0.50
- validation WER is below 0.80
- outputs are readable and mostly non-repetitive
- printed and handwritten subsets both improve
- the best checkpoint is found early and can be restored
- the result is stable across at least 2 different random splits
That last point matters because your current dataset is tiny. One split can flatter or punish you.
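To make the two-split check cheap, keep the split deterministic and seed-controlled. A small helper sketch, assuming your samples live in one flat list:

```python
import random

def make_split(items, val_fraction=0.2, seed=0):
    # Deterministic shuffle-and-cut split. Run the full recipe twice
    # with two different seeds and compare best validation CER between
    # the runs before trusting the result.
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)
    n_val = max(1, int(len(indices) * val_fraction))
    val_idx = indices[:n_val]
    train_idx = indices[n_val:]
    return [items[i] for i in train_idx], [items[i] for i in val_idx]

samples = [f"line_{i:03d}" for i in range(210)]
train_a, val_a = make_split(samples, seed=0)
train_b, val_b = make_split(samples, seed=1)
print(len(train_a), len(val_a))  # 168 42
```

If you also want the printed/handwritten ratio preserved in both halves, split each domain separately with the same helper and concatenate.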
My honest bottom line
Your current case looks like this:
- The architecture is no longer the main problem.
- The trainer is now the main problem.
- 210 samples is enough to validate the pipeline, but not enough to judge model quality.
- Full fine-tuning ~400M parameters on 160 training lines is the wrong experiment.
- Freezing the decoder and training the encoder only is the wrong direction for this model.
- A staged frozen-encoder, decoder-side adaptation strategy is the correct next step. (Hugging Face)
So my advice is:
- keep the current architecture,
- fix the trainer,
- rerun the 210-line proof with CER-driven control, early stopping, and per-domain metrics,
- then move to 1000+ samples only after you get a stable result.
Here is the minimal patch I would make to your last two cells only.
This patch keeps your current model and dataset flow intact. It changes the trainer so that CER becomes the main control metric, generation runs every epoch, the frozen-encoder / decoder-side adaptation strategy stays in place, and AdamW uses the standard weight-decay split that excludes norm and bias parameters. That matches how warm-started vision-encoder-decoder hybrids typically need decoder-side fine-tuning, how mT5's decoder starts generation from pad_token_id, and how Hugging Face examples group AdamW parameters. (Hugging Face)
It also fixes the specific failure mode your current trainer has: loss can keep improving while free-generation OCR gets worse. CER is a character-level edit-distance metric where lower is better, and insertion-heavy outputs can drive it above 1.0, so it is the better signal for checkpointing and early stopping in your setup. (GitHub)
Replace Cell 14 with this
import os
import re
import random
import unicodedata
import torch
import torch.nn as nn
from jiwer import wer as compute_wer_score
def normalize_text(text: str) -> str:
text = unicodedata.normalize("NFC", text)
text = text.strip()
text = re.sub(r"\s+", " ", text)
return text
class HindiOCRTrainer:
def __init__(
self,
model,
train_loader,
val_loader,
tokenizer,
device,
output_dir,
num_epochs=20,
learning_rate=1e-4,
grad_clip_norm=1.0,
patience=5,
min_delta=0.005,
max_new_tokens=64,
):
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.tokenizer = tokenizer
self.device = device
self.output_dir = output_dir
self.num_epochs = num_epochs
self.grad_clip_norm = grad_clip_norm
self.patience = patience
self.min_delta = min_delta
self.max_new_tokens = max_new_tokens
os.makedirs(output_dir, exist_ok=True)
# Restore dropout after single-sample overfit test
dropout_rate = model.mt5.config.dropout_rate
for m in model.modules():
if isinstance(m, nn.Dropout):
m.p = dropout_rate
# -----------------------------
# Freeze strategy: Stage A
# -----------------------------
# Freeze full encoder
for p in model.encoder.parameters():
p.requires_grad = False
# Train decoder cross-attention + output-side params
for name, p in model.mt5.named_parameters():
p.requires_grad = (
("EncDecAttention" in name) or
("lm_head" in name) or
("shared" in name) or
("layer_norm" in name)
)
# Train encoder->decoder projection if present
if model.enc_to_dec_proj is not None:
for p in model.enc_to_dec_proj.parameters():
p.requires_grad = True
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable:,} / {total:,}")
# -----------------------------
# Optimizer with param groups
# -----------------------------
bridge_params = []
decay_params = []
no_decay_params = []
for name, p in model.named_parameters():
if not p.requires_grad:
continue
if "enc_to_dec_proj" in name or "EncDecAttention" in name:
bridge_params.append(p)
elif any(x in name for x in ["bias", "LayerNorm.weight", "layer_norm.weight", "shared"]):
no_decay_params.append(p)
else:
decay_params.append(p)
optimizer_groups = []
if bridge_params:
optimizer_groups.append({
"params": bridge_params,
"lr": 2e-4,
"weight_decay": 0.01,
})
if decay_params:
optimizer_groups.append({
"params": decay_params,
"lr": learning_rate,
"weight_decay": 0.01,
})
if no_decay_params:
optimizer_groups.append({
"params": no_decay_params,
"lr": learning_rate,
"weight_decay": 0.0,
})
self.optimizer = torch.optim.AdamW(optimizer_groups)
# CER is the control signal, not val_loss
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
mode="min",
factor=0.5,
patience=2,
min_lr=1e-6,
)
self.best_cer = float("inf")
self.best_wer = float("inf")
self.best_checkpoint_path = os.path.join(self.output_dir, "best_model.pt")
self.training_log = []
self.bad_epochs = 0
# ============================================================
# TRAIN
# ============================================================
def _train_one_epoch(self):
self.model.train()
total_loss = 0.0
for batch in self.train_loader:
pixel_values = batch["pixel_values"].to(self.device)
labels = batch["labels"].to(self.device)
self.optimizer.zero_grad()
outputs = self.model(pixel_values=pixel_values, labels=labels)
loss = outputs.loss
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
self.optimizer.step()
total_loss += loss.item()
return total_loss / len(self.train_loader)
# ============================================================
# VALIDATE
# ============================================================
def _validate(self, epoch):
self.model.eval()
total_loss = 0.0
all_predictions = []
all_ground_truths = []
with torch.no_grad():
for batch in self.val_loader:
pixel_values = batch["pixel_values"].to(self.device)
labels = batch["labels"].to(self.device)
outputs = self.model(pixel_values=pixel_values, labels=labels)
total_loss += outputs.loss.item()
generated_ids = self.model.generate(
pixel_values=pixel_values,
max_new_tokens=self.max_new_tokens,
num_beams=1,
do_sample=False,
)
clean_labels = labels.clone()
clean_labels[clean_labels == -100] = self.tokenizer.pad_token_id
pred_texts = self.tokenizer.batch_decode(
generated_ids, skip_special_tokens=True
)
gt_texts = self.tokenizer.batch_decode(
clean_labels, skip_special_tokens=True
)
pred_texts = [normalize_text(x) for x in pred_texts]
gt_texts = [normalize_text(x) for x in gt_texts]
all_predictions.extend(pred_texts)
all_ground_truths.extend(gt_texts)
avg_val_loss = total_loss / len(self.val_loader)
cer = self._compute_cer(all_predictions, all_ground_truths)
wer = compute_wer_score(all_ground_truths, all_predictions)
exact_match = sum(p == g for p, g in zip(all_predictions, all_ground_truths)) / max(1, len(all_predictions))
return avg_val_loss, cer, wer, exact_match, all_predictions, all_ground_truths
# ============================================================
# CER
# ============================================================
def _compute_cer(self, predictions, ground_truths):
total_edits = 0
total_chars = 0
for pred, gt in zip(predictions, ground_truths):
total_edits += self._edit_distance(pred, gt)
total_chars += len(gt)
return 0.0 if total_chars == 0 else total_edits / total_chars
def _edit_distance(self, s1, s2):
m, n = len(s1), len(s2)
dp = list(range(n + 1))
for i in range(1, m + 1):
prev = dp[0]
dp[0] = i
for j in range(1, n + 1):
temp = dp[j]
if s1[i - 1] == s2[j - 1]:
dp[j] = prev
else:
dp[j] = 1 + min(prev, dp[j], dp[j - 1])
prev = temp
return dp[n]
# ============================================================
# CHECKPOINT
# ============================================================
def _save_checkpoint(self, epoch, is_best=False):
checkpoint = {
"epoch": epoch,
"model_state": self.model.state_dict(),
"optimizer_state": self.optimizer.state_dict(),
"scheduler_state": self.scheduler.state_dict(),
"best_cer": self.best_cer,
"best_wer": self.best_wer,
"training_log": self.training_log,
}
if is_best:
torch.save(checkpoint, self.best_checkpoint_path)
print(f" ✅ Best model saved → {self.best_checkpoint_path}")
if epoch % 5 == 0:
periodic_path = os.path.join(self.output_dir, f"checkpoint_epoch_{epoch}.pt")
torch.save(checkpoint, periodic_path)
print(f" 💾 Periodic checkpoint saved → {periodic_path}")
# ============================================================
# MAIN LOOP
# ============================================================
def train(self):
print("=" * 60)
print(f"Starting training for {self.num_epochs} epochs")
print("Generation metrics computed every epoch")
print("Primary control metric: CER")
print("Target CER for PoC: < 0.50")
print("=" * 60)
for epoch in range(1, self.num_epochs + 1):
avg_train_loss = self._train_one_epoch()
avg_val_loss, cer, wer, exact_match, predictions, ground_truths = self._validate(epoch)
# Scheduler follows CER, not val_loss
self.scheduler.step(cer)
current_lr = self.optimizer.param_groups[0]["lr"]
log_entry = {
"epoch": epoch,
"train_loss": avg_train_loss,
"val_loss": avg_val_loss,
"cer": cer,
"wer": wer,
"exact_match": exact_match,
"lr": current_lr,
}
self.training_log.append(log_entry)
print(
f"Epoch {epoch:>2}/{self.num_epochs} | "
f"Train Loss: {avg_train_loss:.4f} | "
f"Val Loss: {avg_val_loss:.4f} | "
f"CER: {cer:.4f} | "
f"WER: {wer:.4f} | "
f"EM: {exact_match:.3f} | "
f"LR: {current_lr:.2e}"
)
# Qualitative preview
print(" Sample predictions:")
indices = random.sample(range(len(predictions)), min(3, len(predictions)))
for i in indices:
print(f" GT: '{ground_truths[i]}'")
print(f" Pred: '{predictions[i]}'")
improved = cer < (self.best_cer - self.min_delta)
if improved:
self.best_cer = cer
self.best_wer = wer
self.bad_epochs = 0
self._save_checkpoint(epoch, is_best=True)
print(f" 🎯 New best CER: {cer:.4f}")
if cer < 0.50:
print(" ✅ TARGET REACHED: CER < 0.50")
else:
self.bad_epochs += 1
print(f" No CER improvement. Patience: {self.bad_epochs}/{self.patience}")
self._save_checkpoint(epoch, is_best=False)
if self.bad_epochs >= self.patience:
print(" ⏹️ Early stopping triggered.")
break
print("\n" + "=" * 60)
print("Training complete.")
print(f"Best CER: {self.best_cer:.4f}")
print(f"Best WER: {self.best_wer:.4f}")
if os.path.exists(self.best_checkpoint_path):
checkpoint = torch.load(self.best_checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint["model_state"])
print("✅ Restored best checkpoint into model.")
return self.training_log
# ============================================================
# OPTIONAL RESUME
# ============================================================
def load_checkpoint(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint["model_state"])
self.optimizer.load_state_dict(checkpoint["optimizer_state"])
self.scheduler.load_state_dict(checkpoint["scheduler_state"])
self.best_cer = checkpoint["best_cer"]
self.best_wer = checkpoint["best_wer"]
self.training_log = checkpoint.get("training_log", [])
start_epoch = checkpoint["epoch"] + 1
print(f"✅ Resumed from epoch {checkpoint['epoch']} | Best CER so far: {self.best_cer:.4f}")
return start_epoch
Replace Cell 15 with this
trainer = HindiOCRTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
tokenizer=tokenizer,
device=device,
output_dir="/content/drive/MyDrive/trocr-checkpoints",
num_epochs=20,
learning_rate=1e-4,
grad_clip_norm=1.0,
patience=5,
min_delta=0.005,
max_new_tokens=64,
)
training_log = trainer.train()
What changed, and why
The smallest important changes are these:
- CER/WER now run every epoch, so the best OCR checkpoint cannot be skipped between 5-epoch intervals.
- Scheduler now follows CER, not validation loss.
- Early stopping stops the run once CER stops improving.
- Parameter groups give the bridge and cross-attention a slightly higher LR, while excluding norm and bias-style parameters from weight decay in the usual Hugging Face pattern. (GitHub)
- Text normalization is applied before metrics, which matters for Unicode-heavy scripts like Devanagari.
- The class still keeps your current frozen encoder + decoder-side adaptation strategy, which is the right first stage for a warm-started vision-text hybrid where decoder-side cross-attention is the newly learned bridge. (Hugging Face)
One small check before you rerun
Your dataset cell still uses max_length=64 for labels. Before the next run, quickly inspect how many targets are hitting that cap. If a noticeable fraction of your Hindi lines are truncated at 64 tokens, raise it first. Otherwise the trainer can improve while the metric ceiling stays artificially low.
Use this once, anywhere after dataset creation:
lengths = []
for i in range(len(train_dataset)):
ids = train_dataset[i]["labels"]
valid = (ids != -100).sum().item()
lengths.append(valid)
print("max label tokens:", max(lengths))
print("p95 label tokens:", sorted(lengths)[int(0.95 * len(lengths))])
print("num hitting 64:", sum(x >= 64 for x in lengths))
What I would expect after this patch
If the pipeline is healthy, you should see this pattern:
- training loss goes down,
- validation CER improves in the first several epochs,
- the best checkpoint appears before the final epoch,
- and early stopping restores that checkpoint instead of letting later degradation define the run.
If after this patch CER still stays around 1.0 or worse, I would not jump to “different tokenizer” or “different architecture.” I would first test one of these two moves:
- keep the same freeze strategy and move to 1000+ line crops, or
- keep the same data and switch to LoRA-style adaptation on the decoder-side trainable blocks, which is a common low-resource adaptation path in recent Indic OCR work. (ACL Anthology)