# UGP/bart.py  (GitHub listing: 11 lines, 348 B, Python)

from transformers import BartConfig, BartForSequenceClassification, BartModel
from torch import nn
class BartForClassification(BartForSequenceClassification):
    """``BartForSequenceClassification`` whose head emits ``num_labels`` logits.

    All modules and the forward pass are inherited unchanged from
    ``BartForSequenceClassification``; this subclass only fixes how many
    classes the classification head predicts.

    Args:
        config: BART model configuration. NOTE: ``config.num_labels`` is
            overwritten in place with ``num_labels`` so that the classification
            head and everything else that reads ``config.num_labels`` (label
            maps, the inherited loss computation) agree with each other.
        num_labels: Number of output classes. Defaults to 4, the class count
            this model has always targeted.

    Raises:
        ValueError: If ``num_labels`` is smaller than 1.
    """

    def __init__(self, config: BartConfig, num_labels: int = 4) -> None:
        if num_labels < 1:
            raise ValueError(f"num_labels must be >= 1, got {num_labels}")
        # Set before the parent constructor runs so that it builds
        # ``classification_head.out_proj`` with ``num_labels`` outputs and an
        # input width taken from the config (``config.d_model``) instead of a
        # hard-coded 768, which would mismatch e.g. bart-large (d_model=1024).
        config.num_labels = num_labels
        # Must happen before any submodule is assigned on ``self``:
        # nn.Module.__setattr__ rejects module attributes set before
        # Module.__init__ has run.
        super().__init__(config)