Answering my own question: subclass `Trainer` and override its `compute_loss` method (see example here).
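For reference, here is a minimal sketch of that approach. It assumes the Hugging Face `transformers` `Trainer` and uses a class-weighted cross-entropy purely as an illustration of a custom loss. The `WeightedLossTrainer` class and its `class_weights` argument are hypothetical names, not part of the library. The exact `compute_loss` signature has changed between `transformers` releases (newer ones may pass extra arguments such as `num_items_in_batch`), so the override accepts `**kwargs` to stay compatible.

```python
import torch
from torch import nn
from transformers import Trainer


class WeightedLossTrainer(Trainer):
    """Illustrative Trainer subclass with a class-weighted cross-entropy loss."""

    def __init__(self, *args, class_weights=None, **kwargs):
        # `class_weights` is a hypothetical extra argument for this sketch;
        # everything else is forwarded unchanged to Trainer.__init__.
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs any extra arguments newer transformers versions pass.
        labels = inputs["labels"]
        # Build the forward-pass inputs without mutating the caller's batch.
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        outputs = model(**model_inputs)
        logits = outputs["logits"]

        weight = None
        if self.class_weights is not None:
            # Move the weights to the same device as the logits.
            weight = self.class_weights.to(logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        # Flatten so this also works for token-level (batch, seq, classes) logits.
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
```

The subclass is then constructed like a regular `Trainer`, with the extra keyword added, e.g. `WeightedLossTrainer(model=model, args=training_args, train_dataset=train_ds, class_weights=torch.tensor([1.0, 2.0]))`. Because only `compute_loss` is overridden, the rest of the training loop (optimizer, scheduler, logging, evaluation) is inherited unchanged.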