This is pretty similar to how most classification models work anyway.
Most embedding models and base models used for classification, like BERT, share the same bidirectional transformer architecture.
With an embedding model you take the hidden state (optionally passed through a final down-projection) of a single token, or the average over all tokens.
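Roughly something like this, as a minimal sketch (assuming the HuggingFace transformers API and a BERT-style checkpoint, both just for illustration):

```python
import torch
from transformers import AutoModel, AutoTokenizer

name = "bert-base-uncased"  # any bidirectional encoder works here
tok = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)

inputs = tok("an example sentence", return_tensors="pt")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state   # (1, seq_len, hidden_dim)

# Option A: single-token embedding (the [CLS] token at position 0)
cls_embedding = hidden[:, 0]

# Option B: mean-pool all token hidden states, masking out padding
mask = inputs["attention_mask"].unsqueeze(-1)    # (1, seq_len, 1)
mean_embedding = (hidden * mask).sum(1) / mask.sum(1)
```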
With BERT and similar models you have to add the classification head yourself, which down-projects from the hidden state to the class logits.
There you also get the advantage of fine-tuning the whole model, but conceptually it's close enough.
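The head itself is basically just a linear down-projection on top of the same encoder, something like this (sketch only, class and variable names are made up):

```python
import torch.nn as nn

class BertClassifier(nn.Module):
    def __init__(self, encoder, num_classes):
        super().__init__()
        self.encoder = encoder  # e.g. the AutoModel loaded above
        self.head = nn.Linear(encoder.config.hidden_size, num_classes)

    def forward(self, **inputs):
        hidden = self.encoder(**inputs).last_hidden_state
        return self.head(hidden[:, 0])   # logits from the [CLS] hidden state
```

When you fine-tune this, both the encoder and the head get gradient updates, which is the "training the whole model" advantage mentioned above.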
