# File-viewer header captured with the paste: 11 lines, 348 B, Python
import copy

from torch import nn
from transformers import BartConfig, BartForSequenceClassification, BartModel
class BartForClassification(BartForSequenceClassification):
    """BART sequence classifier whose head emits a fixed number of logits.

    The head size is set by ``num_labels``, which defaults to 4. The head's
    input width is not hard-coded: the parent constructor derives it from
    the config, so any BART size works.
    """

    def __init__(self, config: BartConfig, num_labels: int = 4):
        """Build the classifier.

        Args:
            config: BART configuration. A private deep copy is made, so the
                caller's object is not modified.
            num_labels: Number of output classes. This OVERRIDES
                ``config.num_labels``. It defaults to 4, the size the
                original hard-coded head used.
        """
        # Work on a copy so the caller's (possibly shared) config is untouched.
        config = copy.deepcopy(config)
        # Size the head through the config instead of swapping out_proj after
        # construction. The parent then builds the head with the right output
        # size, and anything reading config.num_labels (e.g. the loss
        # computation in forward) stays consistent with it.
        config.num_labels = num_labels
        # nn.Module.__init__ must run (via the parent constructor) before any
        # submodule is assigned on self. The parent also creates the
        # encoder-decoder and classification head from `config`.
        super().__init__(config)

    @property
    def bart(self):
        """Backward-compatible alias for code that accessed ``model.bart``.

        The old implementation wrapped a second ``BartForSequenceClassification``
        under this name. The classifier is now ``self``, so the alias returns it.
        """
        return self