Would this concept model work?

It’s an MDLM with ternary weights, hybrid q8/q4 activations, and a 3-bit KV cache, used with block diffusion. The training code is really messy so I don’t really want to share the PyTorch training code, but I’m trying to train a 1B model on 40B training tokens.


Since BitNet works, I suppose it’s conceptually possible…


It can work as a research model. I would not expect the first full run to be the easiest or safest way to get the best 1B model from 40B tokens. The concept is technically plausible because each major piece has public precedent: native ternary training in BitNet, hybrid 8/4-bit activation handling plus 3-bit KV support in BitNet a4.8, strong masked diffusion language modeling in MDLM, and block diffusion for arbitrary-length semi-parallel generation with KV caching. The risk is that your design combines four hard optimizations at once, and the public literature does not yet show a mature, standard recipe for that exact stack in one training run. (arXiv)

The short answer

My judgment for your exact case is:

  • Conceptually sound: yes.
  • Likely to train at all: yes, if staged carefully.
  • Likely to be stable from step 1 with messy code: no, that is the weak point.
  • Likely to beat a simpler 1B baseline on quality-per-training-compute immediately: probably not.
  • Likely to become an interesting efficient system if you sequence the hard parts carefully: yes. (arXiv)

Why the idea is reasonable

Your design is coherent. It is trying to reduce cost in four different places:

  1. Weights with ternary values. BitNet b1.58 explicitly uses ternary {-1,0,1} weights and argues that native low-bit training can match full-precision models of similar scale and token budget. Microsoft’s later 2B4T report extends that claim to a larger open model. (arXiv)

  2. Runtime activations with a hybrid 8/4-bit path. BitNet a4.8 is the closest public match here. It does not just say “4-bit activations everywhere.” It says 4-bit for selected inputs to attention and FFN, while sparsifying and 8-bit-quantizing intermediate states to control outliers. (arXiv)

  3. Attention state memory with low-bit KV cache. KIVI shows 2-bit KV compression can preserve quality well while cutting memory a lot, QServe shows practical W4A8KV4 serving, and Google’s TurboQuant claims 3-bit KV compression without retraining or fine-tuning and without quality loss on its reported benchmarks. (arXiv)

  4. Generation order with diffusion instead of fully sequential AR decoding. MDLM shows masked diffusion language models can get much closer to AR than older diffusion NLP work, and block diffusion adds arbitrary-length generation, KV reuse, and parallel token sampling. (arXiv)

So the high-level idea is not random. It is a real systems thesis. (arXiv)

Why I am still cautious

The public evidence is favorable to each ingredient separately, but much less favorable to all of them being hard at once.

The main warning sign is low-bit attention. Attn-QAT says reliable 4-bit attention is difficult because FP4 has tiny dynamic range and attention activations are heavy-tailed. It also reports that naive “drop-in” QAT leads to instability if the backward pass assumes higher precision than the forward pass actually used. That is extremely relevant to your hybrid 8/4-bit activation design. (arXiv)

The second warning sign is diffusion-specific quantization. A recent systematic study on quantizing diffusion LLMs says dLLMs have activation outliers, that W8A8 is usually close to lossless, and that W4A4 is still hard, especially for harder tasks. Another diffusion-LLM quantization paper says dynamic masking, iterative generation, and bidirectional attention all clash with standard quantization assumptions. That is exactly the place where your model is most exposed: low-bit activations inside a diffusion-style attention stack. (arXiv)

So the architecture is plausible, but the fragile junction is very specific: diffusion + low-bit attention activations, not ternary weights alone and not KV compression alone. (arXiv)

What each ingredient means in your setup

1. MDLM plus block diffusion

This is the most novel modeling part of your stack. MDLM is now a serious baseline, not just an experiment. It showed masked diffusion can approach AR perplexity with a strong training recipe. Block diffusion then extended that line by interpolating AR and diffusion behavior, adding variable-length generation, KV caching, and parallel blockwise sampling. (arXiv)

But diffusion language modeling still carries a tax. A 2026 scaling study says masked diffusion can be made about 12% more FLOPs-efficient with a simpler cross-entropy objective, yet also argues that perplexity is not sufficient across diffusion families and that some interpolating methods have different speed-quality tradeoffs. Another controlled comparison found AR and MDLM had similar raw training throughput on the tested setup, but AR converged faster while MDLM kept improving longer. That means diffusion is viable, but it is not yet the easy default. (arXiv)

For you, that means block diffusion is a real feature, not a gimmick. But it also means you are starting from a training regime that is less forgiving than plain AR. (arXiv)

2. Ternary weights

This is the strongest part of your plan. BitNet and BitNet b1.58 are the cleanest public evidence that native ternary training can work at scale. Microsoft’s 2B4T technical report strengthens that case further. If I had to choose one part of your concept to trust the most, it would be the ternary-weight core. (arXiv)

The caution is not “ternary is fake.” The caution is that the official public BitNet stack is much more mature on inference than on open training. The official repo is an inference framework, and public requests about training code and training behavior are still visible. That suggests native ternary training is real, but not yet as operationally standardized as ordinary Transformer pretraining. (GitHub)

3. Hybrid q8 / q4 activations

This is the most important word in your description: hybrid.

That makes your idea much more plausible than “all 4-bit activations everywhere.” BitNet a4.8 is effectively telling you that selective low-bit activation handling is the viable route. It keeps some paths at 4-bit, some in sparsified 8-bit form, and frames the whole thing as a strategy to mitigate quantization errors from outlier channels. (arXiv)

The public literature around low-bit training also points toward staging rather than all-hard-mode-from-start. ParetoQ reports that, in its experiments, the best results come from doing most of training in higher precision and only a smaller final portion in QAT. That is not your exact setup, but it supports the same practical lesson: the hardest quantization should usually enter late, not dominate the entire run from the first step. (arXiv)

4. 3-bit KV cache

This part is increasingly plausible, but it is mostly an inference-side win, not a reason your pretraining will be easier. KIVI and TurboQuant are both about reducing serving memory and improving throughput, not improving base optimization. QServe makes the same general point from another angle: efficient low-bit serving depends heavily on systems co-design, not just on math. (arXiv)

So I would not use “3-bit KV cache” as part of the justification for training stability. I would treat it as a deployment feature that you want the architecture to tolerate well. (arXiv)

What 1B with 40B tokens means

For a 1B model, 40B tokens is 40 tokens per parameter. By older AR scaling heuristics, that is not absurdly low. Chinchilla’s headline example was 70B parameters trained on 1.4T tokens, which is about 20 tokens per parameter. So in plain AR terms, 40B tokens for 1B parameters is not obviously undertraining. (arXiv)
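The arithmetic behind those two ratios, as a quick check:

```python
# tokens-per-parameter for the proposed run vs the Chinchilla headline example
proposed = 40e9 / 1e9        # 1B params trained on 40B tokens -> 40 tokens/param
chinchilla = 1.4e12 / 70e9   # 70B params trained on 1.4T tokens -> 20 tokens/param
```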

But that does not mean your setup is comfortably overprovisioned. Diffusion-language work still suggests that some diffusion families need more compute than AR to match likelihood, and scaling results show different diffusion families trade off perplexity and generation speed in nontrivial ways. So your 40B-token budget is enough for a meaningful run, but not enough to absorb many simultaneous training pathologies for free. (arXiv)

My translation of that into plain language is:

  • 40B tokens is enough to train a real 1B model.
  • 40B tokens is not enough to be careless about numerics, schedules, or ablations when you are stacking diffusion and aggressive low-bit behavior. (arXiv)

What I think will happen if you try this exactly as stated

If you turn on all difficult ingredients from the start, the most likely outcomes are:

  1. It trains, but underperforms a simpler baseline.
    This is the most probable result. The model may remain coherent and useful, but lose too much optimization headroom to match a simpler AR or safer MDLM baseline trained on the same budget. That is consistent with current diffusion scaling results and low-bit attention warnings. (arXiv)

  2. The loss looks mostly normal, then destabilizes late.
    Public BitNet issue reports and the low-precision attention literature both point to the possibility of training that appears healthy, then degrades suddenly once the model enters a more sensitive region of optimization. (GitHub)

  3. You misdiagnose engineering trouble as algorithmic failure.
    Because the public ecosystems here are still rough, a messy training codebase makes it harder to tell whether you hit a real modeling limit or just a bad kernel path, masking bug, or inconsistent attention implementation. Public issues on bd3lms and BitNet reinforce that the tooling is still not boring and mature. (GitHub)

What I would do in your case

The main recommendation

I would not run the full intended stack from step 1.

I would instead aim for this order:

  1. Get the MDLM or block-diffusion backbone stable without the hardest activation regime.
    Diffusion is already one research variable. Make that one variable first. MDLM and block diffusion both have strong public recipes, and block diffusion itself emphasizes variance reduction and data-driven schedules. (arXiv)

  2. Use ternary weights early if that is central to the thesis.
    This is the most defensible low-bit choice you have. The public evidence for native ternary training is much stronger than the evidence for full early 4-bit activation training. (arXiv)

  3. Keep activations safer early, then introduce the harder q4 path later.
    This matches the spirit of BitNet a4.8 and the general QAT evidence from ParetoQ. (arXiv)

  4. Evaluate 3-bit KV mainly as an inference layer.
    That is where the literature is strongest. (arXiv)

Why this ordering makes sense

This ordering separates the risks:

  • If the model fails before q4 activations enter, the problem is likely diffusion or ternary numerics.
  • If it only fails after q4 enters, the culprit is probably the low-bit activation path.
  • If pretraining succeeds but long-context inference quality drops, the problem is likely the KV compression layer.

That gives you information instead of a single ambiguous failure. The public literature strongly supports doing this kind of separation, because the failure modes are not all in the same place. (arXiv)

If your code is messy, this matters even more

You said the training code is messy and you do not want to share it. That is fine. It just changes the best strategy.

With clean code and heavy instrumentation, you can afford a more aggressive stack because you can localize failures quickly. With messy code, you want fewer simultaneous sources of instability. The literature does not explicitly say “messy code is bad,” but the state of the public repos and issues strongly implies that these methods are still engineering-sensitive. (GitHub)

There is also a stronger alternative route if your goal is simply “get a diffusion-style model without rewriting everything.” LLaDA’s guidelines say their backbone can be derived from an AR model by simply removing the causal mask, and DiffuLLaMA explicitly argues that training diffusion models from scratch at scale is challenging and that adapting AR models is an effective route. For a messy codebase, that is a major hint: reduce the amount of architectural novelty you inject at once. (GitHub)

My actual verdict

Here is the plain verdict.

Would the concept model work?
Yes, probably, if “work” means “can be trained into a real model and can demonstrate the intended efficiency ideas.” (arXiv)

Would I expect the first end-to-end 1B / 40B-token run, with messy code, to cleanly validate the full concept as stated?
No. The most likely blocker is not ternary weights and not 3-bit KV. It is the combination of diffusion training and low-bit attention activations, especially if q4 is active too early or too broadly. (arXiv)

What is the highest-confidence version of your idea?
A staged one: ternary core first, safer activation path first, q4 introduced later and selectively, 3-bit KV treated mostly as an inference feature, and block diffusion adopted only after the simpler masked-diffusion path is behaving. (arXiv)

What is the single sentence summary?
Your concept is plausible but over-coupled. The strongest public support is for the ternary core and KV compression, while the strongest public warnings are about low-bit activations inside attention, especially in diffusion-style models. (arXiv)

Thanks for the reply. Tbh I lied a little bit….

I trained a 100M model on about 600M tokens (pretraining only) and it got coherent output, not the best by any means, but I mean, at that point. And I do have somewhat clean-ish code. I used Claude to write most of it, but with my limited knowledge it does seem okay. The repo is Fury7425/bitDiffusion-a4.8 on GitHub, and I asked Claude to write the README based on the conversation that led to this. If you want, please try reading it and tell me whether it would or wouldn’t work.

And just ignore the license part. Idk why Opus put it there, but I’m out of the daily limit to correct it for the time being.


Looks like there might be a slight bug?


I reviewed the actual code path, not the README. I also sanity-checked a tiny forward/backward path locally.

The short verdict is:

This is a real model implementation, not a fake scaffold.
It can plausibly train into a coherent prototype.
I would not launch the 1B run unchanged.
The main reasons are not “ternary is impossible” or “MDLM is wrong.” The core architecture is aligned with the literature: masked diffusion language models are viable, block diffusion is a real semi-autoregressive extension with KV caching, ternary-from-scratch has precedent, and BitNet a4.8-style hybrid activation handling is the right direction. The fragile zone is still low-bit attention/activation behavior, especially when stacked on diffusion. (arXiv)

Final judgment

If you asked me, “Would this codebase probably produce a coherent 1B masked-diffusion model if I spend the compute?”, my answer is:

Probably yes, after a few fixes.

If you asked me, “Would this exact codebase, as-is, cleanly validate the whole concept and be easy to trust at 1B/40B?”, my answer is:

No. It has a solid core plus several correctness and interpretation issues.

What is solid

These parts are good enough that I would keep them.

1. The core modeling choice is valid

The model is a bidirectional denoiser with absorbing-state masking and a per-sample noise level t. That is the right family for MDLM-style training. MDLM specifically showed that simple masked discrete diffusion can be much stronger than older diffusion-for-text setups and can support efficient samplers. (arXiv)
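A minimal pure-Python sketch of that corruption step (the names `corrupt` and `mask_id` are illustrative, not from the repo; the real code works on tensors):

```python
import random

def corrupt(tokens, t, mask_id, rng):
    # Absorbing-state corruption: each token is independently replaced
    # by the mask token with probability t (the per-sample noise level).
    # Loss is later computed only at the masked (supervised) positions.
    noisy, supervised = [], []
    for tok in tokens:
        if rng.random() < t:
            noisy.append(mask_id)
            supervised.append(True)
        else:
            noisy.append(tok)
            supervised.append(False)
    return noisy, supervised
```

In MDLM-style training, t itself is drawn per sample (for example uniformly in (0, 1)), so different examples in a batch see different corruption rates.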

2. The ternary-weight implementation is conceptually sound

The code keeps latent full-precision weights and uses STE ternary quantization in the forward pass. That is the standard kind of construction you would expect from BitNet-style training. The overall idea of native ternary weights is supported by BitNet b1.58. (arXiv)
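For concreteness, a pure-Python sketch of a BitNet b1.58-style absmean ternary quantizer (the function name is mine, and the repo's actual tensor implementation will differ in details):

```python
def absmean_ternarize(w, eps=1e-8):
    # BitNet b1.58-style absmean quantizer (sketch): scale by mean |w|,
    # then round-clip each weight to {-1, 0, 1}. In training, the forward
    # pass uses q * gamma while the backward pass (STE) treats the
    # quantizer as identity and updates the latent full-precision w.
    gamma = sum(abs(x) for x in w) / len(w) + eps
    q = [max(-1, min(1, round(x / gamma))) for x in w]
    return q, gamma
```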

3. The A8 → A4 schedule is the right instinct

This is one of the best choices in the repo. BitNet a4.8 is not “all 4-bit everywhere from the first step.” It is selective and hybrid. Your code is directionally aligned with that. (arXiv)

4. The block sampler has the right basic idea

Your block sampler uses committed context plus an ephemeral current block. That is a sensible prototype for block diffusion. The public block-diffusion work explicitly motivates arbitrary-length generation, KV caching, and parallel token sampling in exactly this general direction. (arXiv)
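A toy sketch of that committed-context loop, with `predict` standing in for the trained denoiser (all names here are hypothetical; this makes no claim about matching the repo's sampler internals):

```python
import random

MASK = -1  # stand-in for the mask token id

def sample_blockwise(n_blocks, block_size, steps, predict, rng):
    # Committed context grows block by block; only the current
    # (ephemeral) block is iteratively denoised.
    context = []
    for _ in range(n_blocks):
        block = [MASK] * block_size
        for _ in range(steps):
            still_masked = [i for i, tok in enumerate(block) if tok == MASK]
            if not still_masked:
                break
            # reveal roughly half of the remaining masked positions per step
            for i in rng.sample(still_masked, max(1, len(still_masked) // 2)):
                block[i] = predict(context, block, i)
        for i, tok in enumerate(block):   # final fill for any leftovers
            if tok == MASK:
                block[i] = predict(context, block, i)
        context.extend(block)             # commit the finished block
    return context
```

The point of the structure is that committed blocks never change again, which is what makes KV reuse across blocks possible.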

Must-fix before a 1B run

These are the items I would treat as hard blockers or near-blockers.

1. MaskDiffusionLoss can return NaN

This is the most important correctness bug I found.

The loss sets all non-supervised positions to ignore_index and then calls F.cross_entropy. If there are zero supervised positions in a batch, PyTorch returns NaN. I verified this locally with a tiny smoke test.

Why this matters:

  • with very long sequences, it is rare that no positions are masked,
  • but it is still a real edge case,
  • and thinking-token exclusion makes it easier for “all masked positions are excluded from loss” to happen on small or special batches.

Fix:

  • before calling cross_entropy, check if not mask_flat.any(): return logits.new_zeros(()).

Without this fix, rare NaNs can poison long training runs.
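The guard logic, shown as a pure-Python analog (the actual fix is the `cross_entropy` guard described above; `masked_mean_loss` is just an illustrative stand-in):

```python
def masked_mean_loss(per_token_losses, supervised):
    # Analog of cross_entropy with ignore_index and reduction="mean":
    # averaging over zero supervised positions would be 0/0 (NaN),
    # so return an exact zero instead.
    picked = [l for l, s in zip(per_token_losses, supervised) if s]
    if not picked:
        return 0.0
    return sum(picked) / len(picked)
```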

2. The variable-length curriculum is mostly canceled by the dataloader

Your data-prep script creates variable-length chunks. But StreamingJsonlDataset re-tokenizes the stored "text" and appends everything into one token buffer, then emits fixed max_length chunks.

So end to end, the effective training stream is mostly fixed-length re-chunked windows, not the intended weighted length distribution.

Why this matters:

  • your experiments become harder to interpret,
  • your training is less like the intended curriculum than you think,
  • if you believe shorter and longer contexts are both important, the current pipeline largely throws that away.

Fix:

  • store tokenized chunks directly, or
  • keep one JSONL line = one training example, do not flatten the entire corpus back into a global rolling token buffer.

3. attention_mask is built, then ignored

The collator returns attention_mask, but the training loop only uses batch["input_ids"]. The model forward path also has no attention-mask input.

Today this is partly hidden because the dataset mostly emits full-length chunks. But partial chunks still exist, and if you later restore true variable-length batching, this becomes a serious issue.

Why this matters:

  • padded positions can enter the corruption process,
  • padded positions can contribute to attention,
  • pad token is set to eos_token if missing, so the model can learn from EOS-padding artifacts.

Fix:

  • propagate attention_mask into masking and loss,
  • exclude padded positions from apply_mask,
  • exclude them from supervised loss,
  • ideally add real attention masking if you want genuine variable-length batches.
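A sketch of what padding-aware corruption would look like (pure Python, illustrative names only, not the repo's `apply_mask`):

```python
import random

def apply_mask_padded(tokens, attention_mask, t, mask_id, rng):
    # Padded positions (attention_mask == 0) are never corrupted and
    # never supervised, so EOS-padding artifacts cannot enter the loss.
    noisy = list(tokens)
    supervised = [False] * len(tokens)
    for i, (tok, real) in enumerate(zip(tokens, attention_mask)):
        if real == 1 and rng.random() < t:
            noisy[i] = mask_id
            supervised[i] = True
    return noisy, supervised
```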

4. BlockDiffusionSampler.generate() is wrong for num_samples > 1

This is a real logic bug.

The block sampler accumulates all_generated and block_texts from block_ids[0], then uses that same shared buffer when returning results for all samples. So if num_samples > 1, the returned outputs are effectively copies of sample 0.

Fix:

  • keep all_generated per sample, not once globally,
  • keep block_texts per sample too.

If you only ever sample one output at a time, this does not hurt you. But it is still a bug.

5. generate_sample() can sample special tokens and silently turn them into the last vocabulary token

In the training monitor sampler, you sample from the full logit tensor, not just the normal vocabulary slice, and then at the end clamp IDs into [0, vocab_size - 1].

That means:

  • if the model samples the mask token or think token,
  • the code silently converts it to the last normal token ID.

This does not corrupt training directly. It corrupts your qualitative monitoring and makes samples less trustworthy.

Fix:

  • slice logits to :vocab_size before sampling, like your other samplers already do.

Should-fix

These are not guaranteed failures, but they weaken the experiment.

1. Thinking tokens are under-supervised

In code terms, think positions are excluded from the direct supervised loss. They only receive gradient indirectly through answer quality.

That can work as an experimental latent-variable trick. But it is weak supervision.

My expectation:

  • maybe helpful,
  • maybe ignored,
  • maybe unstable if you over-interpret it as “reasoning.”

For a first serious 1B run, I would either:

  • disable thinking tokens, or
  • keep only the simplest global-prefix version and remove per-block thinking.

2. Per-block thinking at inference does not match training

Training prepends one think prefix to the sequence.
Block sampling can prepend think tokens before every block.

That is a train-test mismatch.

It may still “work” in the loose sense that the model produces something. But if the feature matters at all, this mismatch makes the result harder to trust.

3. The KV quantization scheme is simpler than the strongest public guidance

Your active cache path uses a simple per-head absmax quantizer for both keys and values.

KIVI’s main conclusion is that keys and values do not want the same treatment: keys work better with per-channel quantization, values with per-token quantization. So your cache may still work, but it is not using the best-supported asymmetry yet. (arXiv)
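A minimal sketch of that asymmetry (illustrative absmax quantizers only; KIVI's actual scheme is more involved, with grouping and a recent full-precision window):

```python
def quantize_keys_per_channel(K, levels=7):
    # Keys: one absmax scale per channel (column), the KIVI-style choice.
    n, d = len(K), len(K[0])
    scales = [max(abs(K[i][j]) for i in range(n)) or 1.0 for j in range(d)]
    q = [[round(K[i][j] / scales[j] * levels) for j in range(d)]
         for i in range(n)]
    return q, scales

def quantize_values_per_token(V, levels=7):
    # Values: one absmax scale per token (row).
    scales = [max(abs(x) for x in row) or 1.0 for row in V]
    q = [[round(x / s * levels) for x in row]
         for row, s in zip(V, scales)]
    return q, scales
```

The difference is only in where the scale lives, which is exactly the asymmetry your current per-head absmax cache does not yet exploit.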

4. The full-sequence denoiser’s KV cache buys almost nothing

In the non-block sampler, you reset the KV cache every denoising step and re-run the full sequence. That is logically correct because the mask pattern changes every step, but it also means the cache is not giving you a real inference win there.

That is not a bug. It just means:

  • KV cache matters mainly for your block sampler,
  • not for the full denoiser.

5. The default run is not actually 40B tokens

The training config computes to about 30.1B tokens, not 40B.

That is not a correctness problem. It is a planning problem. If you want a 40B-token run, your step count needs to change.
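The step-count arithmetic is simple; the numbers below are purely hypothetical placeholders, not values from your config:

```python
# Hypothetical schedule numbers for illustration only; the repo's
# actual batch size and sequence length will differ.
global_batch = 512                 # sequences per optimizer step (assumed)
seq_len = 2048                     # tokens per sequence (assumed)
tokens_per_step = global_batch * seq_len
steps_for_40b = 40_000_000_000 // tokens_per_step
```

Whatever the real config values are, recompute the step count this way rather than trusting the current default.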

Fine for now

These are not where I would spend time first.

1. No causal mask in attention

Correct for diffusion.

2. Latent full-precision weights with STE

Standard for this kind of research implementation.

3. MoE code

Not the current concern because it is off by default.

4. RoPE offset handling in block mode

Directionally correct and useful for committed-context generation.

What I think will happen if you run it unchanged

Most likely:

  • it does train,
  • it gives coherent outputs,
  • the ternary core is not the main reason it fails,
  • the A8 → A4 schedule probably helps rather than hurts,
  • but the final result is harder to interpret because the data pipeline and thinking-token behavior are not clean.

The most likely disappointments are:

  • weaker-than-expected gains from low-bit activations,
  • unclear value from thinking tokens,
  • KV-cache quality below what you would hope from the best papers,
  • and results that are noisier than they need to be because of the data path and edge-case loss behavior. Recent work on 4-bit attention explicitly says attention is the main obstacle because of heavy-tailed activations and precision-mismatch instability, which matches where I would expect your run to be most fragile. (arXiv)

What I think will happen if you fix the blockers

Then I think the code has a real chance to produce a meaningful 1B prototype.

Not “state of the art.”
Not “obviously better than a same-budget AR baseline.”
But a real prototype that demonstrates:

  • masked diffusion training,
  • ternary-weight viability,
  • staged hybrid activation quantization,
  • and blockwise semi-autoregressive generation.

That is a legitimate target. MDLM supports the masked-diffusion backbone, block diffusion supports the blockwise generation idea, BitNet b1.58 supports native ternary weights, and BitNet a4.8 supports the general hybrid A8/A4 direction plus low-bit KV as an inference concept. (arXiv)

My recommendation

Before spending serious compute on 1B, I would do exactly this:

  1. Fix MaskDiffusionLoss for the zero-supervised-token case.

  2. Fix multi-sample block generation.

  3. Fix the qualitative sampler so it cannot turn special tokens into fake normal tokens.

  4. Decide whether you want:

    • real variable-length training, then preserve it end to end and use masks properly, or
    • fixed-length training, then simplify the pipeline and stop pretending otherwise.
  5. Disable thinking tokens for the first real 1B run.

  6. Treat the current KV cache as a prototype cache, not a final serving recipe.

Bottom line

Strict code-review answer:

The codebase is structurally real and probably trainable.
It is not clean enough yet for an unquestioned 1B run.
The main risks are correctness and experiment-interpretation risks, not “the whole concept is impossible.”


Here is the ranked patch plan I would use for your repo.

The ordering is based on one question only: what most reduces the chance of wasting a 1B / 40B-token run. The literature says your backbone choice is plausible: MDLM-style masked diffusion is a real language-modeling family, block diffusion is a real semi-autoregressive extension with KV reuse, ternary-from-scratch has precedent, and BitNet a4.8 supports the general idea of staged hybrid low-bit activations. The main fragility zone remains low-bit attention/activation behavior, not the existence of the overall concept. (arXiv)

Tier 0: patch before any serious 1B run

1. Make MaskDiffusionLoss safe when there are zero supervised positions

Files: bitdiffusion/diffusion.py

Why this is first

This is the only issue I found that can directly produce a silent training poison. I locally verified that your current loss returns NaN when every position is ignored.

In your code, the loss:

  • flattens logits and targets,
  • masks out non-supervised positions,
  • writes ignore_index into all other targets,
  • then calls F.cross_entropy(...).

If all positions are ignored, PyTorch returns NaN.

Patch

Add a guard right before cross_entropy:

if not mask_flat.any():
    return logits.new_zeros(())

Why it matters for your concept

Diffusion training already has noisier supervision than plain next-token prediction because the supervised set changes each batch. Block diffusion adds more schedule sensitivity, and low-bit training leaves less numerical slack. A rare NaN is much more dangerous in this regime than in a boring baseline. The block-diffusion paper explicitly highlights variance control and noise scheduling as first-class engineering concerns, and Attn-QAT shows that low-bit attention is already the main stability bottleneck. (arXiv)

Minimal test

  • unit test with is_masked = torch.zeros(...)
  • assert loss is finite and exactly zero

2. Fix multi-sample block generation

Files: bitdiffusion/sample.py

What is wrong

In BlockwiseDiffusionSampler.generate(), all_generated and block_texts are single shared Python lists, but the method returns one result per sample. The code collects tokens from block_ids[0] only, then reuses that same accumulated sequence for every sample.

So num_samples > 1 is currently wrong.

Patch

Change:

  • all_generated: list[int] = []
  • block_texts: list[str] = []

to per-sample structures, for example:

all_generated = [[] for _ in range(num_samples)]
block_texts = [[] for _ in range(num_samples)]

Then collect and decode per sample.

Why it matters

This does not break single-sample runs. But it makes batched sampling misleading, which is bad for evaluating diversity and sampler correctness. Since MDLM and block diffusion are often judged partly on generation behavior, broken multi-sample output makes the model look more deterministic or cleaner than it really is. (arXiv)

Minimal test

  • run num_samples=2 with a fixed seed and temperature > 0
  • assert outputs are independently tracked
  • assert internal block text lists differ when token traces differ

3. Fix generate_sample() so it cannot sample special tokens and silently map them to normal tokens

Files: bitdiffusion/train.py

What is wrong

Your qualitative monitor sampler samples from the full output vocabulary, then later clamps token IDs to vocab_size - 1. If the model samples the mask token or think token, that special token gets silently turned into the last normal vocabulary token.

So your training samples can look cleaner or stranger for the wrong reason.

Patch

Change:

probs = torch.softmax(logits / temperature, dim=-1)

to:

probs = torch.softmax(logits[:, :, :model.config.vocab_size] / temperature, dim=-1)

Do not rely on post-hoc clamping.

Why it matters

This does not directly affect training, but it absolutely affects whether you trust your monitoring. In diffusion models, qualitative inspection is important because loss curves alone do not tell the whole story about generation quality. (arXiv)

Minimal test

  • force logits to favor mask token
  • assert sampler never returns an out-of-range or silently remapped normal token

4. Decide whether you want true variable-length training or fixed-length training, then make the code match

Files: prepare_hf_jsonl.py, bitdiffusion/data.py, bitdiffusion/train.py

What is wrong

Your prep script creates a variable-length curriculum. Then the dataset loader re-tokenizes each "text" field, concatenates everything into a rolling token buffer, and emits fixed max_length chunks. So the end-to-end training stream is mostly fixed-length again.

Patch choice A: keep variable-length training

  • store tokenized examples directly
  • keep one JSONL example = one training example
  • use attention_mask throughout masking and loss
  • do not re-flatten the corpus into a global rolling token buffer

Patch choice B: admit fixed-length training

  • simplify prep
  • stop generating variable-length chunks upstream
  • keep fixed-length windows deliberately

My recommendation

For your first 1B run, choose B unless variable-length behavior is central to your research question. Fixed-length training is simpler and easier to debug.

Why it matters

Block diffusion papers emphasize variance and schedule quality. If your intended curriculum is being erased by the loader, you do not really know what you trained. Clean experimental semantics matter more here than fancy preprocessing. (arXiv)

Minimal test

  • inspect a batch length histogram after collation
  • confirm it matches what you think the loader is doing
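A tiny helper for that check (illustrative; assumes `pad_id` identifies padding in the collated batches):

```python
from collections import Counter

def length_histogram(batches, pad_id):
    # Count the effective (non-pad) lengths the loader actually emits,
    # one entry per example, to compare against the intended curriculum.
    hist = Counter()
    for batch in batches:
        for seq in batch:
            hist[sum(1 for tok in seq if tok != pad_id)] += 1
    return hist
```

If the intended variable-length curriculum were surviving, this histogram would be spread out; with the current rolling-buffer loader it will be concentrated at max_length.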

Tier 1: fix before spending the full 40B tokens

5. Propagate attention_mask into corruption and loss, or remove padding entirely

Files: bitdiffusion/data.py, bitdiffusion/train.py, bitdiffusion/diffusion.py, bitdiffusion/model.py

What is wrong

The collator builds attention_mask. The training loop ignores it. The model forward path also ignores it.

Right now this is partly masked by your fixed-length behavior. But the moment you preserve variable lengths, padded positions become real positions for:

  • masking,
  • attention,
  • loss bookkeeping.

And because pad defaults to EOS if the tokenizer lacks a pad token, the model can learn EOS-padding artifacts.

Patch

At minimum:

  • exclude padded positions from apply_mask
  • exclude padded positions from MaskDiffusionLoss

If you later restore true variable-length batching:

  • also pass an attention mask into attention

Why it matters

This is less urgent than the NaN fix because your current loader mostly emits full chunks. But once you want honest variable-length behavior, this becomes a correctness issue, not a cleanup. (arXiv)


6. Disable thinking tokens for the first serious 1B baseline

Files: bitdiffusion/diffusion.py, bitdiffusion/train.py, bitdiffusion/sample.py

Why

This is the weakest-supervised subsystem in the code.

The code explicitly excludes thinking positions from direct supervised loss and expects them to become useful only through downstream answer gradients. That is possible in principle, but it is a weak signal. Also, training prepends one think prefix to the whole sequence, while the block sampler can prepend think tokens before every block. That is a train-test mismatch.

Patch

For the baseline 1B run:

  • set N_think = 0
  • set think_prob = 0
  • keep the code, but remove it from the main experiment

Then add it back only after the baseline works.

Why it matters

Your core concept does not need thinking tokens to be valid. MDLM, block diffusion, ternary weights, and hybrid A8/A4 already make a complete research story. Thinking tokens add ambiguity without adding much confidence. (arXiv)


7. Keep the current KV cache labeled as a prototype, and do not overfit conclusions to it

Files: bitdiffusion/quantization.py, bitdiffusion/sample.py

What is happening

Your active cache path uses a simple per-head absmax scheme for both keys and values. That is fine for a prototype, but it is simpler than the best-supported KV-cache quantization approaches.

KIVI’s main result is that keys and values want different treatment: keys per-channel, values per-token. Your current path does not do that. (arXiv)

Patch

Do one of these:

  • leave the current cache as-is, but call it a prototype cache and benchmark it honestly
  • or implement asymmetric K/V quantization closer to KIVI

My recommendation

For the first 1B run, keep it simple and prototype-level. Do not burn time rewriting the cache before the base model is proven.

Why it matters

KV cache is mostly an inference feature in your code, not a training feature. So this is not a blocker for pretraining. It is a blocker for making strong claims about deployment efficiency or quality retention. KIVI shows the asymmetry matters. (arXiv)
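To make the asymmetry concrete, here is a toy comparison of the two granularities on fake K/V tensors. This is quantize-dequantize simulation only (no packed 3-bit storage), the shapes are assumptions, and `quantize_absmax` is a stand-in for your quantization.py helper:

```python
import torch

def quantize_absmax(x, n_bits=3, dim=None):
    """Symmetric absmax quantize-dequantize along `dim` (None = per-tensor)."""
    qmax = 2 ** (n_bits - 1) - 1  # e.g. 3 for signed 3-bit
    if dim is None:
        scale = x.abs().amax().clamp_min(1e-8) / qmax
    else:
        scale = x.abs().amax(dim=dim, keepdim=True).clamp_min(1e-8) / qmax
    return (x / scale).round().clamp(-qmax - 1, qmax) * scale

torch.manual_seed(0)
# K, V: (batch, heads, seq, head_dim)
K = torch.randn(1, 8, 128, 64)
V = torch.randn(1, 8, 128, 64)

# Prototype path: one scale per head for both K and V.
K_proto = quantize_absmax(K, dim=(2, 3))
V_proto = quantize_absmax(V, dim=(2, 3))

# KIVI-style asymmetry: keys per-channel (reduce over seq),
# values per-token (reduce over head_dim).
K_kivi = quantize_absmax(K, dim=2)
V_kivi = quantize_absmax(V, dim=3)
```

Comparing mean reconstruction error of the two paths on your real cache contents is a cheap way to decide whether the asymmetric rewrite is worth doing before or after the base run.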


8. Add one explicit ablation checkpoint before the A8 → A4 switch

Files: bitdiffusion/train.py

Patch

Save:

  • one checkpoint right before the activation-mode switch
  • one checkpoint shortly after entering A4 mode

Also log:

  • masked-token accuracy
  • answer-only loss
  • fraction of masked positions per batch
  • gradient norm
  • activation mode

Why it matters

Recent low-bit attention work identifies 4-bit attention as the main obstacle, owing to heavy-tailed activations and precision mismatch. If your run degrades, you want to know whether the break started:

  • before A4,
  • exactly at A4,
  • or long after. (arXiv)
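A sketch of what the checkpoint-and-log hook could look like. `A4_SWITCH_STEP`, `POST_SWITCH_PROBE`, and the surrounding training state are hypothetical placeholders for your train.py, not names from your code:

```python
import torch

A4_SWITCH_STEP = 30_000   # hypothetical step where A8 -> A4 happens
POST_SWITCH_PROBE = 500   # save a probe checkpoint shortly after the switch

def maybe_checkpoint(model, optimizer, step, out_dir="ckpts"):
    """Save one checkpoint just before the switch and one shortly after."""
    if step in (A4_SWITCH_STEP - 1, A4_SWITCH_STEP + POST_SWITCH_PROBE):
        tag = "pre_a4" if step < A4_SWITCH_STEP else "post_a4"
        torch.save({"model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step,
                    "activation_mode": "a8" if step < A4_SWITCH_STEP else "a4"},
                   f"{out_dir}/{tag}_step{step}.pt")

def log_step_metrics(logits, targets, supervise_mask, grad_norm, mode):
    """Per-step scalars to log alongside the loss."""
    with torch.no_grad():
        preds = logits.argmax(dim=-1)
        acc = (preds[supervise_mask] == targets[supervise_mask]).float().mean()
        frac_masked = supervise_mask.float().mean()
    return {"masked_acc": acc.item(),
            "frac_masked": frac_masked.item(),
            "grad_norm": grad_norm,
            "activation_mode": mode}
```

Logging `activation_mode` as an explicit field makes the A4 transition trivially visible when you later plot the run.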

Tier 2: worth fixing, but not before the first scaled run

9. Separate the “full denoiser” and “block sampler” evaluation stories

Files: bitdiffusion/sample.py

Why

Your full denoiser resets the KV cache every denoising step, which is logically correct because the full masked pattern changes every step. That means KV cache does not buy much there. The real cache benefit is in the block sampler.

Patch

Report them separately:

  • full diffusion sampling quality
  • blockwise generation quality and speed
  • KV cache effect only inside the blockwise path

Why it matters

It makes your conclusions cleaner and more aligned with what block diffusion is actually buying. (arXiv)


10. Add smoke tests for the exact failure cases above

Files: tests/

Add these tests

  • MaskDiffusionLoss zero-supervision returns finite zero
  • generate_sample() never emits special-token IDs at ordinary token positions
  • BlockwiseDiffusionSampler.generate(num_samples>1) returns independent per-sample outputs
  • data loader preserves intended length behavior
  • one tiny forward/backward pass on CPU works

Why it matters

Your code is close enough to working that small regressions matter. At 1B scale, simple tests are much cheaper than one wasted launch.
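As one example, the special-ID test can be written against a tiny stand-in sampler. `sample_non_special` and the special-ID set are illustrative, not your generate_sample() signature; the point is the masking pattern and the assertion:

```python
import torch

def sample_non_special(logits, special_ids):
    """Sample token IDs with special-token logits masked to -inf.

    logits: (B, T, V). Returns sampled IDs of shape (B, T).
    """
    masked = logits.clone()
    masked[..., list(special_ids)] = float("-inf")  # zero probability
    probs = torch.softmax(masked, dim=-1)
    flat = torch.multinomial(probs.view(-1, probs.size(-1)), 1)
    return flat.view(probs.shape[:-1])

def test_no_special_ids_in_output():
    torch.manual_seed(0)
    logits = torch.randn(4, 16, 32)
    special = {0, 1, 2}  # e.g. pad/eos/mask -- adjust to your tokenizer
    ids = sample_non_special(logits, special)
    assert ids.shape == (4, 16)
    assert not any(int(i) in special for i in ids.flatten())
```

Each of the other four tests follows the same shape: tiny tensors, one property, one assertion.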


Tier 3: optional improvements after the baseline works

11. If you want better KV behavior, move toward asymmetric quantization

This is where I would spend time after the base 1B model works. KIVI gives a strong hint that asymmetry between keys and values is not cosmetic. (arXiv)

12. If you want stronger A4 confidence, add more attention-specific diagnostics

Attn-QAT makes it very clear that the hard part is not generic quantization. It is attention numerics. That suggests logging:

  • attention score range
  • softmax entropy
  • per-head activation max
  • fraction of saturated quantized values during A4 mode (arXiv)
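Those four diagnostics fit in one helper. The tensor shapes and the `qmax` convention below are assumptions about your attention implementation, not its actual interface:

```python
import torch

def attention_diagnostics(scores, quantized, qmax):
    """A4-mode diagnostics for one attention call.

    scores: pre-softmax attention scores, (B, H, T, T)
    quantized: integer-quantized activations (any shape)
    qmax: max representable magnitude of the quantized format
    """
    with torch.no_grad():
        probs = torch.softmax(scores, dim=-1)
        # Mean per-row softmax entropy; collapse toward 0 can flag saturation.
        entropy = -(probs * probs.clamp_min(1e-12).log()).sum(-1).mean()
        return {
            "score_min": scores.amin().item(),
            "score_max": scores.amax().item(),
            "per_head_act_max": scores.abs().amax(dim=(0, 2, 3)),  # (H,)
            "softmax_entropy": entropy.item(),
            "saturation_frac": (quantized.abs() >= qmax).float().mean().item(),
        }
```

A rising `saturation_frac` or a collapsing `softmax_entropy` after the A4 switch is exactly the kind of early signal that distinguishes attention-numerics failure from a generic training problem.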

13. If you want thinking tokens back, add a real training signal

Do this only after the plain model works. Right now they are more of a research hypothesis than a dependable subsystem.


The patch order I would actually execute

This week

  1. Fix MaskDiffusionLoss NaN case.
  2. Fix block sampler multi-sample bug.
  3. Fix generate_sample() vocabulary slicing.
  4. Decide fixed-length vs variable-length training and simplify accordingly.

Before the 1B launch

  1. Disable thinking tokens for baseline.
  2. Add mask-aware loss/corruption if you keep any variable-length batching.
  3. Add checkpoints around the A4 transition.
  4. Add the small tests.

After the baseline run

  1. Improve KV asymmetry.
  2. Add attention-specific A4 diagnostics.
  3. Reintroduce thinking tokens only as an ablation.

How I think it works after this patch plan

If you apply Tier 0 and Tier 1, I think the model has a real chance to do what you want in the limited sense that matters first:

  • train a 1B masked-diffusion model,
  • keep the ternary-weight core,
  • switch into hybrid low-bit activation late,
  • produce coherent blockwise-generated text,
  • and give you a trustworthy baseline for later KV and thinking-token experiments.

That expectation is supported by the literature around MDLM, block diffusion, BitNet b1.58, and BitNet a4.8. (arXiv)

If you do not apply the top patches, I think the likely failure mode is not dramatic collapse. It is worse: an expensive run that “sort of works,” but leaves you unsure whether the weak points came from diffusion, low-bit attention, your data path, or the under-supervised thinking-token mechanism. Recent 4-bit attention results are exactly why that distinction matters. (arXiv)