How do you optimize TweetyNet training for automated annotation?
Hi @riifaaa, this is a good question. Thank you for asking.
Sorry I haven’t replied sooner; I had a deadline at my day job.
The short answer is: it depends, which I know isn’t very helpful. I will say as much as I can about each of the things you tried to optimize, and about interpreting the metrics. Please keep in mind that I’m basing this mainly on what we know from the datasets in the TweetyNet paper, and on what we’ve heard from other people using TweetyNet. Most of those datasets consist of many samples of vocalizations from single, acoustically isolated individuals, like one male songbird in a soundbox recorded over several days. IIUC, you have field recordings? If so, the statistics of your data might be somewhat different. (@kaiyaprovost’s experience with data from Xeno-Canto might also be helpful here.)
I’m guessing you’ve done this already, but just plotting descriptive statistics of your data might help (e.g., how many samples there are for each label class, what the distribution of segment durations is, etc.). I suggest that because I keep learning the hard way how important it is to inspect my data, even when I think I already have a good feel for it. If you can share more detail about your data, either here or by email with me if you prefer (nicholdav at gmail), then I might be able to give you more targeted advice.
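As a starting point for those descriptive statistics, here is a minimal sketch. It assumes you have already loaded your annotations into a list of `(label, onset_s, offset_s)` tuples; that format and the function name `describe_annotations` are hypothetical, so adapt them to however your annotation files are stored.

```python
from collections import Counter
from statistics import mean, median


def describe_annotations(segments):
    """Summarize annotated segments.

    `segments` is assumed to be an iterable of (label, onset_s, offset_s)
    tuples, one per annotated syllable (hypothetical format).
    """
    segments = list(segments)
    # How many samples there are for each label class
    counts = Counter(label for label, _, _ in segments)
    # Duration of each segment, in seconds
    durs = [offset - onset for _, onset, offset in segments]
    return {
        "n_segments": len(segments),
        "counts_per_label": dict(counts),
        "min_dur": min(durs),
        "median_dur": median(durs),
        "mean_dur": mean(durs),
        "max_dur": max(durs),
    }
```

Rare label classes and unusually short or long segments tend to stand out quickly in a summary like this; a histogram of the durations (e.g. with matplotlib) shows the full distribution.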
With those caveats out of the way, I’ll try to speak to each of the things you mention.
- Improving the training dataset (increasing the number of recordings per song type group)
- Splitting the `train_dur`, `val_dur`, and `test_dur` into 50%, 25%, and 25% of my training dataset
On the first point: the more data the better, as we know.
On the second point, the problem of course is that you only have so much data. In a perfect world you would have enough data that the validation and test sets give you unbiased estimates of how the model performs on a new draw from the underlying distribution. But we don’t live in a perfect world, so we set aside some of our data to check that the model is training (validation) and to test it, and this means we might end up with a “pessimistic bias”: the model is trained on less data than we actually have, so its measured performance can underestimate how well it would do if trained on all of it.
In the paper we found that for Bengalese finches, ~10 minutes of training data was about the smallest amount we could use to get the best-performing model; anything more than that gave diminishing returns. We used 80 s for the validation set and 400 s for the test set. That works out to roughly 56% train : 7% validation : 37% test. Don’t read too much into that, but take it as a rough ballpark. I would guess you can reduce the amount you’re using for validation significantly, so put that data in the training set!
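The split arithmetic is easy to redo for your own durations. A minimal sketch (the function name is just for illustration), using the ~10 min / 80 s / 400 s figures from the paper setup described above:

```python
def split_percentages(train_dur, val_dur, test_dur):
    """Return the share (in %) of the total duration that goes to each split."""
    total = train_dur + val_dur + test_dur
    return tuple(round(100 * d / total, 1) for d in (train_dur, val_dur, test_dur))


# Durations in seconds: ~10 min train, 80 s validation, 400 s test
print(split_percentages(600, 80, 400))  # -> (55.6, 7.4, 37.0)
```

Plugging in your own `train_dur`, `val_dur`, and `test_dur` makes it easy to see how much data a 50/25/25 split is holding out compared with a smaller validation set.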
To give you a better idea of how much training data you need, you can generate a learning curve like we did in the TweetyNet paper. You can find an example configuration file in the vak repository on GitHub (vocalpy/vak, main branch): vak/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml.
This is basically like running `vak train` with a bunch of config files, where in each file you vary the size of the training set and the actual samples that are in that training set, while holding the validation and test sets constant. You do this by specifying the options `train_set_durs` and `num_replicates` in the `[PREP]` section of the config file. The `train_set_durs` option is an array of training set durations; these will become the x-axis of your learning curve. The `num_replicates` option is the number of training replicates, that is, the number of models trained “from scratch” (with randomly initialized weights) for each training set duration. Each of these training sets is a randomly drawn subset of the total pool of training data, whose size you specify with the `train_dur` option.
To get an initial sense of whether you have enough training data, I would run a learning curve where your current training set duration is close to the maximum duration you use, and where you have a couple of smaller durations. Say, if your training set is 100 seconds now, you could use [10, 20, 50, 100] (or something like that); since every training set is drawn from the `train_dur` pool, none of these values can be larger than `train_dur`. You can use a small number of replicates (say, 3) so you get a quick first answer. When you plot the metrics from the curve, you’ll see whether the error is still going down as you add data and whether you are close to zero error or not. I realize I don’t have a good, concise plotting snippet handy; the plotting code in this notebook is very tailored to specific results but should give you some idea. I can reply to myself later with a better snippet.
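To make the learning-curve mechanics a bit more concrete, here is a minimal sketch of the two pieces: drawing random training subsets of each duration from the pool, and summarizing a metric across replicates so it can be plotted. This is a simplified illustration under stated assumptions, not vak’s actual implementation. The pool is assumed to be a list of `(filename, duration_s)` pairs, and the results are assumed to be rows with hypothetical keys `"train_set_dur"` and `"segment_error"` (e.g. read with `csv.DictReader`); check what your results files actually contain.

```python
import random
from collections import defaultdict
from statistics import mean, stdev


def draw_train_subsets(pool, train_set_durs, num_replicates, seed=0):
    """Draw random training subsets for each (duration, replicate) pair.

    `pool` is a list of (filename, duration_s) pairs: the full training pool.
    Files are shuffled and added until the subset reaches the target duration.
    """
    rng = random.Random(seed)
    pool_dur = sum(dur for _, dur in pool)
    subsets = {}
    for target in train_set_durs:
        # Every subset is drawn from the pool, so it can't be bigger than the pool
        if target > pool_dur:
            raise ValueError(f"training set duration {target} s exceeds pool of {pool_dur} s")
        for replicate in range(1, num_replicates + 1):
            files = pool[:]
            rng.shuffle(files)
            subset, total = [], 0.0
            for name, dur in files:
                if total >= target:
                    break
                subset.append(name)
                total += dur
            subsets[(target, replicate)] = subset
    return subsets


def summarize_learncurve(rows, metric="segment_error"):
    """Return (train_set_dur, mean, std) of `metric` across replicates, sorted by duration."""
    by_dur = defaultdict(list)
    for row in rows:
        by_dur[float(row["train_set_dur"])].append(float(row[metric]))
    summary = []
    for dur in sorted(by_dur):
        vals = by_dur[dur]
        sd = stdev(vals) if len(vals) > 1 else 0.0
        summary.append((dur, mean(vals), sd))
    return summary
```

The summary is what you would put on the learning curve: for example, `matplotlib.pyplot.errorbar(durs, means, yerr=sds)` with the training set duration on the x-axis and the mean error (with standard deviation across replicates) on the y-axis.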
- Increasing the `batch_size` parameter (here I am limited by my laptop spec; I tried `10`, `15`, and `20`)
- Increasing the `min_segment_dur` to `0.05` (targeted syllables are around `0.1` to `0.3` seconds)
I would be surprised if a small batch size is limiting you in any way. Subjectively, in initial experiments we noticed slightly lower performance with larger batch sizes, closer to what people typically use in computer vision, so we used smaller batch sizes, e.g. 8 for Bengalese finch song. There is some evidence that smaller batch sizes can actually help optimization (but note that in that work, “small” means 32 to 512!). We haven’t tested the effect of batch size extensively, though.
We have noticed that the post-processing steps have a big impact on performance, perhaps bigger than anything else we have tried. See for example figure 5 of the TweetyNet paper. See also the results in this recent poster: you will notice that the lowest error is always obtained after post-processing with clean-ups, and that this reduction is much larger than what we can squeeze out of the model by changing hyperparameters.
Are you also using `majority_vote = True` in your `post_tfm_kwargs`? You will definitely want to do this, especially when you generate predictions.
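To give a rough idea of what these clean-ups do to the network’s output, here is a minimal sketch of a minimum-segment-duration clean-up followed by a majority vote, applied to per-frame predictions. It assumes integer frame labels where `0` means “unlabeled”/background, and it treats a segment as a run of consecutive non-background frames. The function names are hypothetical, the sketch simply relabels too-short segments as background, and it is an illustration rather than vak’s own code.

```python
from collections import Counter

UNLABELED = 0  # assumption: 0 is the background / "unlabeled" class


def find_segments(frame_labels, unlabeled=UNLABELED):
    """Return (start, stop) indices of runs of non-background frames."""
    segs, start = [], None
    for i, lbl in enumerate(frame_labels):
        if lbl != unlabeled and start is None:
            start = i
        elif lbl == unlabeled and start is not None:
            segs.append((start, i))
            start = None
    if start is not None:
        segs.append((start, len(frame_labels)))
    return segs


def remove_short_segments(frame_labels, timebin_dur, min_segment_dur, unlabeled=UNLABELED):
    """Relabel segments shorter than min_segment_dur (seconds) as background."""
    out = list(frame_labels)
    for start, stop in find_segments(out, unlabeled):
        if (stop - start) * timebin_dur < min_segment_dur:
            out[start:stop] = [unlabeled] * (stop - start)
    return out


def majority_vote(frame_labels, unlabeled=UNLABELED):
    """Within each segment, assign every frame the segment's most common label."""
    out = list(frame_labels)
    for start, stop in find_segments(out, unlabeled):
        winner, _ = Counter(out[start:stop]).most_common(1)[0]
        out[start:stop] = [winner] * (stop - start)
    return out


def postprocess(frame_labels, timebin_dur, min_segment_dur):
    """Minimum-duration clean-up first, then majority vote (order chosen for this sketch)."""
    cleaned = remove_short_segments(frame_labels, timebin_dur, min_segment_dur)
    return majority_vote(cleaned)
```

For example, with 10 ms time bins and a 20 ms minimum, `[0, 1, 1, 2, 1, 0, 0, 3, 0]` becomes `[0, 1, 1, 1, 1, 0, 0, 0, 0]`: the one-frame “3” blip is dropped, and the stray “2” inside the first syllable is outvoted. A single wrong frame like that counts as an extra segment (and extra edits) before clean-up, which is one reason post-processing can reduce segment error so much.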
I was wondering what else I can do to improve the training step.
The other thing I might try, besides increasing the training data, would be to increase (1) the window size and (2) the hidden size of the recurrent network. If anything, I would make the window size as large as you can, even if that means you use a smaller batch size (like, 4!). A bigger window size actually helps; that was the punch line of that poster. (But again, post-processing seems to have a bigger impact.)
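Because the window size is counted in spectrogram time bins, it can help to convert it to seconds and compare it with your syllable durations. A minimal sketch, assuming each time bin advances by `step_size` audio samples (the hop between spectrogram frames); the numbers below (a 176-bin window, a step size of 64, 32 kHz audio) are purely illustrative, not recommendations:

```python
def window_duration_s(window_size, step_size, samplerate):
    """Seconds of audio covered by a window of `window_size` time bins,
    assuming each time bin advances by `step_size` samples."""
    timebin_dur = step_size / samplerate  # duration of one time bin, in seconds
    return window_size * timebin_dur


def syllables_per_window(window_size, step_size, samplerate, syllable_dur_s):
    """Roughly how many syllables of a given duration fit in one window."""
    return window_duration_s(window_size, step_size, samplerate) / syllable_dur_s


# Illustrative values: 176 bins * (64 / 32000) s per bin = 0.352 s
print(window_duration_s(176, 64, 32000))
```

With syllables of 0.1 to 0.3 s, a window like that holds only about one to three and a half syllables, so the network sees little surrounding context. As a rough rule, the number of time bins processed per batch grows with `batch_size` × `window_size`, which is why doubling the window while halving the batch size is a reasonable trade to try on a laptop.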
As someone who is only just learning about ML, I am not familiar with the terms `val_acc` or `val_levenshtein`, etc.
I’ll explain here, but see also the “Metrics” section of the methods in the TweetyNet paper (annoyingly, eLife won’t give me a direct link to that section).
`val_acc` is the per-frame accuracy on the validation set, where a “frame” is a time bin in the spectrogram, i.e., the fraction of all frames that had the correct label. You can convert this to “frame error” (see Ch. 5 of Graves’ dissertation) by subtracting it from 1.0. That’s helpful when you want to compare with edit distance / segment error, since then both metrics go in the same direction (lower is better), which makes graphs easier to read.
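A minimal sketch of frame accuracy and frame error, assuming the ground-truth and predicted labels are equal-length sequences with one label per spectrogram time bin (function names are just for illustration):

```python
def frame_accuracy(y_true, y_pred):
    """Fraction of time bins (frames) whose predicted label matches the ground truth."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same number of frames")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def frame_error(y_true, y_pred):
    """Frame error is simply 1 - frame accuracy."""
    return 1.0 - frame_accuracy(y_true, y_pred)


y_true = [0, 1, 1, 1, 0, 2, 2, 0]
y_pred = [0, 1, 1, 0, 0, 2, 1, 0]
print(frame_accuracy(y_true, y_pred))  # -> 0.75
print(frame_error(y_true, y_pred))     # -> 0.25
```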
`val_levenshtein` is the Levenshtein distance on the validation set. It would probably be better if we called this the edit distance, since that’s clearer and avoids naming the metric after any particular person. Intuitively, you can think of it as “how many edits do I have to make to the predictions to get back to the ground-truth labels?” More edits = a bigger distance.
`val_segment_error` is a normalized version of the edit distance: we divide the distance by the total number of labels in the reference (ground-truth) string. We want this normalized metric so we can compare across sequences of different lengths. Here again we probably could have picked a better name, since it sounds too much like the names given to metrics of the segmentation itself, like “segmental edit distance”. (In the paper we say “syllable error rate”, but I don’t like that name either. Naming things is hard.)
Note that all these metrics are computed per sample and then averaged. So basically they are telling you something like “on average, each sample in my dataset will have n incorrect labels at the frame / segment level”, where a “sample” is one element in a batch, i.e. a spectrogram generated from one audio file.
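Here is a minimal sketch of the edit distance, the normalized segment error, and the per-sample averaging described above. It assumes each sample’s labels are given as a string (or list) with one element per syllable, e.g. `"iabcde"`; the function names are hypothetical, and vak’s own implementation may differ in details.

```python
from statistics import mean


def levenshtein(ref, hyp):
    """Edit distance: minimum number of insertions, deletions and substitutions
    needed to turn the predicted label sequence `hyp` into the reference `ref`."""
    # prev[j]: distance between the ref labels processed so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(
                prev[j] + 1,             # ref label r has no match in hyp
                curr[j - 1] + 1,         # hyp label h has no match in ref
                prev[j - 1] + (r != h),  # match (cost 0) or substitution (cost 1)
            ))
        prev = curr
    return prev[-1]


def segment_error(ref, hyp):
    """Edit distance normalized by the number of labels in the reference."""
    if len(ref) == 0:
        raise ValueError("reference sequence is empty")
    return levenshtein(ref, hyp) / len(ref)


def mean_segment_error(pairs):
    """Average of per-sample segment errors over (ref, hyp) pairs."""
    return mean([segment_error(ref, hyp) for ref, hyp in pairs])
```

For example, with reference `"iabcde"` and prediction `"iabde"` (one missed syllable), the distance is 1 and the segment error is 1/6. Averaging per sample gives a different number than pooling: adding a second, perfectly predicted sample `"ab"` gives a mean per-sample error of 1/12, whereas pooling the edits and labels across both samples would give 1/8.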
I know that’s a very long-winded answer! Just trying to give you as much good info as I can. Happy to answer follow up questions too, hope it helps!