from transformers import BartConfig, BartForSequenceClassification, BartModel
from torch import nn


class BartForClassification(BartForSequenceClassification):
    """BART sequence classifier whose head has a fixed number of classes (4 by default).

    This is a thin specialisation of ``BartForSequenceClassification``: it
    forces the label count on the config before the parent constructor runs,
    so the parent's own classification head, weight initialisation and loss
    computation are all sized consistently.
    """

    def __init__(self, config: BartConfig, num_labels: int = 4):
        """Build the classifier.

        Args:
            config: BART configuration. It is deep-copied, so the caller's
                object keeps its original ``num_labels``.
            num_labels: Number of output classes. Defaults to 4, the width
                the original implementation hard-coded.

        Raises:
            ValueError: If ``num_labels`` is less than 1.

        NOTE(review): when loading through ``from_pretrained(..., num_labels=k)``
        the value is applied to the config and then overridden here by this
        constructor's ``num_labels`` argument. Confirm this is the desired
        precedence for your callers.
        """
        import copy

        if num_labels < 1:
            raise ValueError(f"num_labels must be >= 1, got {num_labels}")

        # Set the label count on the config *before* building the model.
        # The parent constructor sizes the classification head's output
        # projection from (config.d_model -> config.num_labels), and the
        # inherited forward() reshapes logits with config.num_labels when
        # computing the loss, so the head width and the config must agree.
        # The previous ``nn.Linear(768, 4)`` was attached as an unused
        # ``out_proj`` attribute, not to the head, and assumed d_model == 768.
        config = copy.deepcopy(config)
        config.num_labels = num_labels

        # Must run before any submodule is assigned. nn.Module.__setattr__
        # rejects module assignment until Module.__init__ has run. This call
        # also creates the encoder-decoder and classification head that the
        # inherited forward() uses, and stores ``self.config``.
        super().__init__(config)