Batch Predictions
How to make batch predictions in fastai
Making batch predictions on new data is not provided “out of the box” in fastai. Here is how you can achieve it.
Add this method to Learner:
from fastcore.all import patch, tuplify
from fastai.learner import Learner

@patch
def predict_batch(self:Learner, item, rm_type_tfms=None, with_input=False):
    "Predict on a collection of items at once"
    dl = self.dls.test_dl(item, rm_type_tfms=rm_type_tfms, num_workers=0)
    inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
    i = getattr(self.dls, 'n_inp', -1)
    inp = (inp,) if i==1 else tuplify(inp)
    # decode the whole batch; transposing gives the decoded inputs and the class names
    dec_inp,nm = zip(*self.dls.decode_batch(inp + tuplify(dec_preds)))
    res = preds,nm,dec_preds
    if with_input: res = (dec_inp,) + res
    return res
You can then use this method like so:
>>> from fastai.text.all import *
>>> from predict_batch import predict_batch  # the patch above saved as predict_batch.py; alternatively, just define it in your script
>>> dls = TextDataLoaders.from_folder(untar_data(URLs.IMDB), valid='test')
>>> learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)
>>> learn.fine_tune(4, 1e-2)
>>> learn.predict_batch(["hello world"]*4)
(TensorText([[0.0029, 0.9971],
             [0.0029, 0.9971],
             [0.0029, 0.9971],
             [0.0029, 0.9971]]),
 ('pos', 'pos', 'pos', 'pos'),
 TensorText([1, 1, 1, 1]))
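If you also want the decoded inputs back (e.g. the processed text), pass with_input=True. A minimal sketch based on the patch above (the variable names here are just illustrative; the decoded inputs are prepended to the returned tuple):
>>> dec_inp, probs, class_names, class_idxs = learn.predict_batch(["hello world"]*4, with_input=True)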
Alternatively, you can just patch the predict function so it works on batches:
from fastcore.all import patch, tuplify, detuplify
from fastai.learner import Learner

@patch
def predict(self:Learner, item, rm_type_tfms=None, with_input=False):
    dl = self.dls.test_dl(item, rm_type_tfms=rm_type_tfms, num_workers=0)
    inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
    i = getattr(self.dls, 'n_inp', -1)
    inp = (inp,) if i==1 else tuplify(inp)
    dec = self.dls.decode_batch(inp + tuplify(dec_preds))
    # split each decoded sample into input and target parts, then regroup them across the batch
    dec_inp,dec_targ = (tuple(map(detuplify, d)) for d in zip(*dec.map(lambda x: (x[:i], x[i:]))))
    res = dec_targ,dec_preds,preds
    if with_input: res = (dec_inp,) + res
    return res
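As a rough usage sketch (the names are illustrative; per the patch above the return order is decoded class names, decoded prediction tensor, probabilities):
>>> class_names, class_idxs, probs = learn.predict(["hello world", "what a great movie"])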
Other notes (h/t Zach): learn.dls.vocab or learn.dls.categorize.vocab is another way to get the class names.
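For example, a minimal sketch that maps the raw prediction indices returned by predict_batch back to labels (assuming the single-label IMDB classifier from above, where categorize.vocab holds the class names):
>>> probs, class_names, class_idxs = learn.predict_batch(["hello world"]*4)
>>> vocab = learn.dls.categorize.vocab  # e.g. ['neg', 'pos'] for IMDB
>>> [vocab[int(i)] for i in class_idxs]  # matches the 'pos' predictions shown above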