RoBERTa base with no year embedding

commit d99a1c4e96
parent b1f955fa88
dev-0/out.tsv: 13172 lines (file diff suppressed because it is too large)
@@ -16,14 +16,14 @@ test_tokenized_datasets = test_dataset.map(tokenize_function, batched=True)
 train_dataset = tokenized_datasets["train"].shuffle(seed=42)
 eval_dataset_full = tokenized_datasets["test"]
-eval_dataset = tokenized_datasets["test"].select(range(2000))
+eval_dataset_small = tokenized_datasets["test"].select(range(2000))

 with open('train_dataset.pickle','wb') as f_p:
     pickle.dump(train_dataset, f_p)

-with open('eval_dataset.pickle','wb') as f_p:
-    pickle.dump(eval_dataset, f_p)
+with open('eval_dataset_small.pickle','wb') as f_p:
+    pickle.dump(eval_dataset_small, f_p)

 with open('eval_dataset_full.pickle','wb') as f_p:
     pickle.dump(eval_dataset_full, f_p)
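This hunk renames the 2000-example evaluation subset to eval_dataset_small so it can coexist with eval_dataset_full. Pickling a datasets.Dataset works, but the datasets library also ships a native serialization path; a minimal sketch under the same tokenized_datasets as above (the directory names here are illustrative, not from the repo):

    # datasets-native alternative to pickle (Dataset.save_to_disk /
    # datasets.load_from_disk); directory names are illustrative.
    from datasets import load_from_disk

    eval_dataset_small = tokenized_datasets["test"].select(range(2000))
    eval_dataset_small.save_to_disk("eval_dataset_small")  # writes an Arrow directory

    # later, e.g. at the top of the training script:
    eval_dataset_small = load_from_disk("eval_dataset_small")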
@@ -4,8 +4,8 @@ from config import LABELS_LIST, MODEL
 with open('train_dataset.pickle','rb') as f_p:
     train_dataset = pickle.load(f_p)

-with open('eval_dataset.pickle','rb') as f_p:
-    eval_dataset = pickle.load(f_p)
+with open('eval_dataset_small.pickle','rb') as f_p:
+    eval_dataset_small = pickle.load(f_p)

 with open('eval_dataset_full.pickle','rb') as f_p:
     eval_dataset_full = pickle.load(f_p)
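The loads mirror the dumps one-to-one, so the rename has to be applied in both scripts. A small helper keeps the file names and variable names in sync; a sketch assuming the pickle files written by the preparation script above:

    import pickle

    def load_pickle(path):
        with open(path, 'rb') as f_p:
            return pickle.load(f_p)

    train_dataset = load_pickle('train_dataset.pickle')
    eval_dataset_small = load_pickle('eval_dataset_small.pickle')
    eval_dataset_full = load_pickle('eval_dataset_full.pickle')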
@@ -25,10 +25,15 @@ training_args = TrainingArguments("test_trainer",
     per_device_train_batch_size=4,
     per_device_eval_batch_size=4,
     evaluation_strategy='steps',
-    eval_steps=2_000,
-    gradient_accumulation_steps=10,
+    #eval_steps=2_000,
+    #save_steps=2_000,
+    eval_steps=20_000,
+    save_steps=20_000,
+    num_train_epochs=1,
+    gradient_accumulation_steps=2,
     learning_rate = 1e-6,
-    warmup_steps=4_000,
+    #warmup_steps=4_000,
+    warmup_steps=4,
     load_best_model_at_end=True,
 )
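The net effect of this hunk: evaluation and checkpointing both move to every 20_000 optimizer steps (load_best_model_at_end=True requires the two schedules to line up, which they do here), warmup is cut to 4 steps, and the effective train batch size drops from 4 × 10 = 40 to 4 × 2 = 8 per device. The resulting arguments, spelled out as one block for reference:

    from transformers import TrainingArguments

    training_args = TrainingArguments(
        "test_trainer",                  # output_dir; checkpoints land here
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        evaluation_strategy='steps',
        eval_steps=20_000,               # evaluate every 20k optimizer steps
        save_steps=20_000,               # checkpoint on the same schedule
        num_train_epochs=1,
        gradient_accumulation_steps=2,   # effective batch: 4 * 2 = 8 per device
        learning_rate=1e-6,
        warmup_steps=4,
        load_best_model_at_end=True,     # needs eval/save schedules to match
    )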
@@ -49,10 +54,11 @@ trainer = Trainer(
     model=model,
     args=training_args,
     train_dataset=train_dataset,
-    eval_dataset=eval_dataset,
+    eval_dataset=eval_dataset_small,
     compute_metrics=compute_metrics,
 )

+#trainer.train(resume_from_checkpoint=True)
 trainer.train()
 trainer.save_model("./roberta-retrained")
 trainer.evaluate()
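The newly added comment line leaves a hook for resuming interrupted runs. With resume_from_checkpoint=True the Trainer restores model weights, optimizer, scheduler, and RNG state from the most recent checkpoint-* directory under the output dir; an explicit path can be passed instead to pick a specific one (the path below is illustrative, not from this repo):

    # Resume from the most recent checkpoint under "test_trainer":
    trainer.train(resume_from_checkpoint=True)

    # Or resume from an explicit checkpoint directory (illustrative path):
    trainer.train(resume_from_checkpoint="test_trainer/checkpoint-20000")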
@@ -8,7 +8,7 @@ device = 'cpu'

 from transformers import AutoModelForSequenceClassification

-model = AutoModelForSequenceClassification.from_pretrained('test_trainer/checkpoint-82000/')
+model = AutoModelForSequenceClassification.from_pretrained('test_trainer/checkpoint-80/')
 tokenizer = AutoTokenizer.from_pretrained(MODEL)

 for dataset in ('dev-0', 'test-A'):
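The body of the prediction loop is outside this hunk. For orientation, a plausible sketch of what it does, assuming each dataset directory holds an in.tsv with one input text per line and that LABELS_LIST (imported from config above) maps class indices to label strings; the in.tsv layout and the loop body itself are assumptions, not part of the diff:

    import torch

    model.to(device)
    model.eval()

    for dataset in ('dev-0', 'test-A'):
        with open(f'{dataset}/in.tsv') as f_in, open(f'{dataset}/out.tsv', 'w') as f_out:
            for line in f_in:
                inputs = tokenizer(line.rstrip('\n'), truncation=True,
                                   return_tensors='pt').to(device)
                with torch.no_grad():
                    logits = model(**inputs).logits
                # map the argmax class index back to its label string
                f_out.write(LABELS_LIST[logits.argmax(dim=-1).item()] + '\n')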
test-A/out.tsv: 13310 lines (file diff suppressed because it is too large)