# UGP/bart.py
# Timestamp: 2023-02-16 18:21:17 +01:00
import copy

from torch import nn
from transformers import BartConfig, BartForSequenceClassification, BartModel
class BartForClassification(BartForSequenceClassification):
    """``BartForSequenceClassification`` with a configurable number of classes.

    The head size is taken from ``num_labels`` (4 by default) instead of a
    hard-coded ``nn.Linear`` on a hard-coded hidden size. ``config.num_labels``
    is set to the same value, so the inherited ``forward`` (which reads it to
    shape logits and pick a loss) stays consistent with the head.
    """

    def __init__(self, config: BartConfig, num_labels: int = 4) -> None:
        """Build the classifier.

        Args:
            config: BART configuration. A deep copy is used, so the caller's
                object is not mutated. The copy's ``num_labels`` is overwritten
                with ``num_labels``.
            num_labels: Number of output classes. Defaults to 4, the head size
                this class was written for.
        """
        # Copy so that changing num_labels does not leak into the caller's config.
        config = copy.deepcopy(config)
        config.num_labels = num_labels
        # The parent constructors must run before any submodule is assigned
        # (nn.Module.__init__ sets up the module registry). The parent sizes its
        # classification head from config.d_model / config.num_labels, so no
        # manual head replacement is needed.
        super().__init__(config)

    @property
    def bart(self) -> "BartForClassification":
        """Backward-compatible alias for code that accessed ``model.bart``.

        The model itself is the BART sequence classifier, so this returns
        ``self``. It is a property rather than an assigned attribute, so it
        does not register a self-referencing submodule.
        """
        return self