Finetune whisper-tiny in German for the TFLite runtime

Hello everyone,

I’ve been trying to finetune a customized whisper-tiny model in German, which I want to serve with the TFLite runtime in a Rust application for fast inference on mobile devices.
Loading the stock model from Hugging Face (openai/whisper-tiny at main) and converting it to TFLite works fine. However, I haven’t found a way to finetune the model in TensorFlow format so that I end up with the model weights as an .h5 file, which I could then load with TFWhisperForConditionalGeneration.from_pretrained() for the TFLite conversion.

All finetuning tutorials I’ve found so far are PyTorch-based, and converting the model from PyTorch via ONNX to TensorFlow format produced a saved_model.pb file, which I can’t load with TFWhisperForConditionalGeneration.from_pretrained().
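To make the mismatch concrete: as far as I can tell, from_pretrained() expects a directory containing config.json plus a weights file such as tf_model.h5, not a SavedModel graph export. Below is a minimal sketch for telling the two layouts apart. The helper name describe_checkpoint_dir and its return labels are made up for illustration, and the file names are the transformers defaults as I understand them.

```python
from pathlib import Path

# Default weight file names used by transformers (assumed, not exhaustive).
TF_WEIGHTS = "tf_model.h5"
PT_WEIGHTS = ("pytorch_model.bin", "model.safetensors")


def describe_checkpoint_dir(path: str) -> str:
    """Return a short label describing what kind of model directory `path` is.

    Hypothetical helper: it only inspects file names, not file contents.
    """
    d = Path(path)
    if (d / "saved_model.pb").is_file():
        # TensorFlow SavedModel graph export, e.g. what onnx_tf writes
        return "tf-savedmodel"
    if not (d / "config.json").is_file():
        return "unknown"
    if (d / TF_WEIGHTS).is_file():
        # transformers checkpoint with TensorFlow weights
        return "transformers-tf"
    if any((d / name).is_file() for name in PT_WEIGHTS):
        # transformers checkpoint with PyTorch weights
        return "transformers-pt"
    return "unknown"
```

Under these assumptions, the directory written by the ONNX route would be reported as "tf-savedmodel", while the stock openai/whisper-tiny download with its .h5 file would be "transformers-tf".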

Hugging Face provides the .h5 format for the whisper-tiny model at the link above, so I assume there must be a way to either finetune the model in TensorFlow or convert it from PyTorch to TensorFlow format, but I haven’t found a solution so far.

Does anyone have a hint on how to solve this issue?

My current PyTorch-to-TF conversion script looks as follows:

import torch
import onnx
import librosa
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import onnx_tf
import numpy as np

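# workaround: restore the np.bool alias removed in NumPy 1.24,
# presumably still referenced by onnx_tf
np.bool = np.bool_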

path = "openai/whisper-tiny"
# only model and processor need to be kept during training
processor = WhisperProcessor.from_pretrained(path)

model = WhisperForConditionalGeneration.from_pretrained(path)

# load exemplary input
y, sr = librosa.load("audio/en.wav", sr=16_000)
assert sr == 16_000, "Only 16k sr supported"
input_features = processor(y, sampling_rate=sr, return_tensors="pt").input_features
# 50258 is <|startoftranscript|> in the multilingual Whisper vocabulary
decoder_input_ids = torch.tensor([[50258]])

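# sanity-check the forward pass before exporting (attention_mask is None)
res = model(input_features, None, decoder_input_ids)
# print(res.logits.shape)
torch.onnx.export(model,
                  (input_features, None, decoder_input_ids),
                  'whisper.onnx',
                  # the None argument does not become a graph input,
                  # so only the two tensor inputs are named
                  input_names=['input_features',
                               'decoder_input_ids'],
                  output_names=['output'],
                  opset_version=14)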

onnx_model = onnx.load('whisper.onnx')

# print("Model Inputs: ", [inp.name for inp in onnx_model.graph.input])

tf_model = onnx_tf.backend.prepare(onnx_model)
tf_model.export_graph("whisper.tf")

requirements:

python 3.10.13
tensorflow 2.15.0
torch 2.1.4
transformers 4.45.2

The scripts that HF staff use for their own conversions may be available on GitHub, though in many cases there is no manual. The scripts below are just examples; explore GitHub.
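One route that may be worth trying before digging through those scripts: the TensorFlow model classes in transformers accept from_pt=True in from_pretrained(), which converts PyTorch weights on load, and save_pretrained() then writes tf_model.h5. A minimal sketch, assuming the fine-tuning was done in PyTorch and saved with save_pretrained(), that both torch and tensorflow are installed, and that the processor files sit in the same directory. The function name and the directory names are placeholders, not anything from the original script.

```python
from pathlib import Path


def convert_pt_checkpoint_to_tf(pt_dir: str, tf_dir: str) -> Path:
    """Load a fine-tuned PyTorch Whisper checkpoint into the TF class
    and save it again so that tf_dir contains tf_model.h5.

    Hypothetical helper; pt_dir is assumed to hold a checkpoint written
    by save_pretrained() after PyTorch fine-tuning.
    """
    # Imported lazily so this module can be imported without TensorFlow.
    from transformers import TFWhisperForConditionalGeneration, WhisperProcessor

    # from_pt=True converts the PyTorch weights while loading.
    tf_model = TFWhisperForConditionalGeneration.from_pretrained(pt_dir, from_pt=True)
    tf_model.save_pretrained(tf_dir)  # writes config.json and tf_model.h5

    # Copy tokenizer and feature-extractor files so tf_dir is self-contained
    # (assumes they were saved next to the PyTorch weights).
    WhisperProcessor.from_pretrained(pt_dir).save_pretrained(tf_dir)
    return Path(tf_dir) / "tf_model.h5"


# Example call with placeholder directory names:
# convert_pt_checkpoint_to_tf("whisper-tiny-de", "whisper-tiny-de-tf")
```

If this works for your checkpoint, the resulting directory should load with TFWhisperForConditionalGeneration.from_pretrained() the same way the stock model does, so the existing TFLite conversion could stay unchanged. I have not verified it with the exact versions listed above.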

Thanks, I will definitely take a look at it :)