11 lines
348 B
Python
11 lines
348 B
Python
|
from transformers import BartConfig, BartForSequenceClassification, BartModel
|
||
|
from torch import nn
|
||
|
|
||
|
class BartForClassification(BartForSequenceClassification):
    """BART sequence classifier with a configurable number of output classes.

    A thin specialization of ``BartForSequenceClassification`` whose
    classification head emits ``num_labels`` logits (4 by default). The
    inherited encoder-decoder and classification head are reused directly;
    no separate wrapped model instance is created.
    """

    def __init__(self, config: BartConfig, num_labels: int = 4):
        """Build the classifier.

        Args:
            config: BART configuration. It is not modified in place; when its
                ``num_labels`` differs from ``num_labels``, a deep copy is
                adjusted and used instead.
            num_labels: Number of logits produced by the classification head.
                Defaults to 4.

        Raises:
            ValueError: If ``num_labels`` is less than 1.
        """
        import copy

        if num_labels < 1:
            raise ValueError(f"num_labels must be >= 1, got {num_labels}")
        if config.num_labels != num_labels:
            # Size the head through the config instead of swapping a layer in
            # afterwards. The parent builds its head from config.d_model and
            # config.num_labels and initialises it in post_init(). Its loss
            # also reshapes logits with config.num_labels, so width,
            # initialisation and loss all stay consistent. Copy first so the
            # caller's config object is left untouched.
            config = copy.deepcopy(config)
            config.num_labels = num_labels
        # Must run before any submodule is assigned: nn.Module.__setattr__
        # refuses module attributes until Module.__init__ has executed.
        super().__init__(config)
|
||
|
|
||
|
|