Piotr Nawrot 2023-03-16 14:59:33 +01:00
commit 73b2e63aac
22 changed files with 2243 additions and 0 deletions

203 LICENSE Normal file
@@ -0,0 +1,203 @@
Copyright 2022 - Piotr Nawrot. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

188 README.md Normal file
@@ -0,0 +1,188 @@
# nanoT5 (Encoder-Decoder / Pre-training + Fine-Tuning)
![nanoT5](assets/nanoT5.png)
[**TLDR**](#tldr) | [**Motivation**](#motivation) | [**Setup**](#setup) | [**References**](#references) | [**Conclusions**](#conclusions) | [**Issues**](#issues)
## TLDR:
This repository contains the code to reproduce the pre-training of a "Large Language Model" (T5) under a limited budget (1xA100 GPU, ~20 hours) in PyTorch. We start from the randomly initialised T5-base-v1.1 (248M parameters) model implemented in HuggingFace. Next, we pre-train it on the English subset of the C4 dataset and then fine-tune it on Super-Natural Instructions (SNI).
**In ~20 hours on a single GPU, we achieve ~40 RougeL on the SNI test set, compared to ~42 RougeL of the original model available on HuggingFace Hub and pretrained through "a combination of model and data parallelism [...] on slices of Cloud TPU Pods", each with 1024 TPUs.**
Our core contribution is not the T5 model itself, which follows the HuggingFace implementation. Instead, we optimise everything else in the training pipeline to offer you a user-friendly starting template for your NLP application/research.
## Motivation
Despite the continuously increasing size of pretrained [Transformers](https://arxiv.org/pdf/1706.03762.pdf), the research community still needs easy-to-reproduce and up-to-date baselines to test new research hypotheses fast and at a small scale.
A recent effort from Andrej Karpathy, the [nanoGPT](https://github.com/karpathy/nanoGPT) repository, enables researchers to pre-train and fine-tune GPT-style (Decoder-only) language models. On the other hand, [Cramming](https://github.com/JonasGeiping/cramming) implements the optimal BERT-style (Encoder-only) pre-training for limited-compute settings.
With [nanoT5](https://github.com/PiotrNawrot/nanoT5), we want to fill a gap (Community requests: [#1](https://github.com/huggingface/transformers/issues/18030) [#2](https://github.com/facebookresearch/fairseq/issues/1899) [#3](https://github.com/google-research/text-to-text-transfer-transformer/issues/172) [#4](https://discuss.huggingface.co/t/example-of-how-to-pretrain-t5/4129) [#5](https://github.com/huggingface/transformers/issues/5079)) of an accessible research template to pre-train and fine-tune T5-style (Encoder-Decoder) models. **To the best of our knowledge, it is the first attempt to reproduce T5 v1.1 pre-training in PyTorch (previously available implementations are in Jax/Flax).**
##
**We created this repository for people who want to pre-train T5-style models by themselves and evaluate their performance on downstream tasks.** This could be for a variety of reasons:
- You are a researcher in academia with limited compute (like me), and you came up with a promising idea to modify the T5 model, so you need a pipeline to evaluate it;
- You have an in-house dataset that you think is more appropriate than the original pre-training data;
- You want to experiment with continued pre-training or want to build on the T5 pre-training objective.
**If you don't need to pre-train the T5 model, you'd be better off downloading the weights from HuggingFace Hub. Our checkpoints are weaker because they were trained under a limited compute budget.**
##
In this project, we expose (for research purposes) and optimise everything in the training pipeline of T5 except for the model implementation. **Most importantly, we base our code on PyTorch, since access to TPUs is limited.** Among others:
- **Dataset:** Downloading and preprocessing of the C4 dataset happens in parallel with the training of the model. The C4 dataset is > 300GB, so it takes a couple of hours to download and even longer to preprocess. This codebase does it on the fly, with no detrimental effect on the training loss that we have observed, although slowdowns might occur with an older CPU (< 8 cores) or a slow internet connection (see the streaming sketch after this list). **As a result, you can start pre-training right after downloading and setting up this repository.**
- **Model Optimizer / LR Scheduler:** The original T5 uses a memory-efficient Adafactor optimizer. [A study on pre-training T5](https://huggingface.co/spaces/yhavinga/pre-training-dutch-t5-models), on the other hand, reports that training does not converge with AdamW. We analysed the source of this discrepancy with several ablations. Although there are many subtle differences between Adafactor and AdamW, what ensures Adafactor's convergence is [matrix-wise LR scaling by its root mean square (RMS)](https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py#L595). We augmented the AdamW implementation with RMS scaling and observed that it becomes **more stable during pre-training, achieves better validation loss, and is faster**.
- **Exposure and simplicity:** We try to balance the implementation of the training pipeline by keeping it customisable while retaining a sufficient level of abstraction. We use the [HuggingFace Accelerator](https://huggingface.co/docs/accelerate/index) to implement operations like Checkpoint Saving, Gradient Accumulation and moving tensors to the correct devices. We use [neptune.ai](https://neptune.ai) for experiment tracking and [hydra](https://hydra.cc/docs/intro/) for hyperparameter search. Apart from this, we expose the training loop, data preprocessing, etc.
- **Efficiency:** We enable TF32 operations (Ampere GPUs) by default, use PyTorch 2.0 compile, and utilise all optimisations listed in established optimisation tutorials [#1](https://huggingface.co/docs/transformers/perf_train_gpu_one) [#2](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html).
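A minimal sketch of the on-the-fly C4 streaming mentioned in the **Dataset** bullet above (not the full pipeline; tokenisation is omitted, and the buffer size and seed simply mirror the repository's defaults):
```
from datasets import load_dataset

# Streaming avoids downloading the full >300GB dataset up front;
# shards are fetched and preprocessed while the model trains.
c4 = load_dataset("c4", "en", streaming=True)
train = c4["train"].remove_columns(["timestamp", "url"])

# Shuffle with a bounded buffer, as done in nanoT5/utils/model_utils.py
train = train.shuffle(buffer_size=10_000, seed=2137)
```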
## Setup
### Environment & Hardware:
```
git clone https://github.com/PiotrNawrot/nanoT5.git
cd nanoT5
conda create -n nanoT5 python=3.8
conda activate nanoT5
pip3 install numpy --pre torch torchvision torchaudio --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cu117
pip install -r requirements.txt
```
The commands above result in the following [pip freeze](assets/env_dump/pip_freeze.txt) as of 15.03.2023.
We also include our [lscpu](assets/env_dump/lscpu.txt) and [nvidia-smi](assets/env_dump/nvidia_smi.txt).
### Pre-training:
#### Reference:
The [T5 v1.1](https://arxiv.org/pdf/2002.05202.pdf) authors report **1.942** negative log-likelihood (NLL) on the held-out set after 2^16 steps.
#### Legacy Optimizer (Adafactor) & LR Schedule (Inverse-Square-Root)
We follow the original experimental setup for pre-training, including [Dataset (C4)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/model_utils.py#L58), [Training Objective (Span Filling)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/copied_utils.py#L16), [Model Architecture (T5-Base)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/configs/default.yaml#L12), [Optimizer (Adafactor)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/model_utils.py#L236), and [LR Schedule (Inverse-Square-Root)](https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/model_utils.py#L276).
Our negative log-likelihood on the held-out set is **1.995**, slightly worse than the reference.
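To reproduce this legacy setup, an illustrative invocation using the hyperparameter names from the Examples section further below would be:
```
python -m nanoT5.main \
    optim.name=adafactor \
    optim.batch_size=128 \
    optim.lr_scheduler=legacy
```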
#### AdamW with RMS scaling Optimizer & Cosine LR Schedule
We also experiment with the AdamW optimizer (instead of the original Adafactor) as it offers more stability during training. Instead of using a low-rank approximation for the second moment of the gradients, it estimates it directly by storing the moving average for each parameter in memory. However, training diverges with AdamW, similar to [this study on T5 pre-training](https://huggingface.co/spaces/yhavinga/pre-training-dutch-t5-models). Through several ablations, we found that [matrix-wise LR scaling by its root mean square (RMS)](https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py#L595) is responsible for the convergence of Adafactor. We augmented the AdamW implementation by RMS scaling and observed that [it converges, becomes more stable during pre-training](assets/pt_loss.png) and is slightly faster (it retrieves the second moment from memory instead of approximating it via matrix multiplications).
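A minimal sketch of this single change (the full optimiser is implemented as `AdamWScale` in `nanoT5/utils/copied_utils.py`); the helper names below are illustrative, not part of the codebase:
```
import torch

def rms(tensor):
    # Root mean square of a weight tensor: ||w||_2 / sqrt(numel)
    return tensor.norm(2) / (tensor.numel() ** 0.5)

def rms_scaled_step_size(step_size, param, floor=1e-3):
    # AdamW's bias-corrected step size, additionally scaled by the
    # parameter's RMS, with a floor so near-zero tensors still move.
    return step_size * max(floor, rms(param).item())

w = torch.randn(768, 768)
print(rms_scaled_step_size(2e-2, w))  # base_lr of 2e-2 rescaled per tensor
```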
However, AdamW, when paired with the Inverse-Square-Root LR schedule, performs worse than Adafactor. For our final experiment, we replace ISR with a Cosine LR Schedule. We achieve **1.953** negative log-likelihood on the held-out set and significantly outperform Adafactor with the ISR schedule.
<div align="center">
| | **Inverse-Square-Root** | **Cosine** |
| :---: | :----: | :---: |
| **Adafactor** | 1.995 | 1.993 |
| **AdamW** | 2.040 | **1.953** |
</div>
#### Increased BS (128 -> 144) to maximise GPU Utilization
We notice that with the original Batch Size of 128, we use only 60GB of the available 80GB of GPU memory. To maximise GPU utilisation and allow for more parallelism, we increase the Batch Size to 144 and consider it **our default pre-training config**. This achieves **1.932** negative log-likelihood on the held-out set, improving upon all previous experiments.
#### Training loss of experiments with different optimisers, schedulers, and batch sizes
![pt_loss](assets/pt_loss.png)
When not indicated in the plot, the batch size is 128.
#### Examples
To reproduce our default pre-training config experiment, run the following:
```
python -m nanoT5.main
```
To reproduce any of the experiments mentioned above, choose any combination of the following hyperparameters:
```
python -m nanoT5.main \
optim.name={adafactor,adamwscale} \
optim.batch_size={128,144} \
optim.lr_scheduler={legacy,cosine}
```
We recommend adding the `model.compile=true` flag for pre-training if you are able to install PyTorch 2.0. In our case, it results in a 1.33x speedup.
If you don't have access to an 80GB GPU, you can increase the number of gradient accumulation steps via `optim.grad_acc=steps`, where `batch_size` has to be divisible by `steps`.
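For instance (an illustrative setting, not a benchmarked one), the default effective batch size of 144 can be kept while splitting each optimisation step into four micro-batches of 36:
```
python -m nanoT5.main \
    optim.batch_size=144 \
    optim.grad_acc=4
```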
A summary of the optimisation process is printed every 100 steps in the following format. For instance:
```
[train] Step 100 out of 65536 | Loss --> 59.881 | Grad_l2 --> 61.126 | Weights_l2 --> 7042.931 | Lr --> 0.010 | Seconds_per_step --> 1.385 |
```
### Fine-tuning:
To fine-tune our model, we use the popular meta-dataset called **Super Natural-Instructions (SNI)**, which aggregates datasets for many tasks. This meta-dataset was used to fine-tune many of the recent LLMs, e.g. [FlanT5](https://arxiv.org/pdf/2210.11416.pdf), [BLOOM](https://arxiv.org/pdf/2211.05100.pdf), and [Tk-Instruct](https://arxiv.org/pdf/2204.07705.pdf). While FlanT5 and BLOOM use other corpora in addition to SNI, Tk-Instruct's pipeline consists of starting from a pre-trained T5 model and fine-tuning it solely on SNI.
In this repository, we reproduce the Tk-Instruct fine-tuning results and use their pipeline to evaluate our pre-training config.
#### Download the Super-Natural Instructions data:
```
git clone https://github.com/allenai/natural-instructions.git data
```
#### Run fine-tuning:
We strictly follow the fine-tuning [config](nanoT5/configs/task/ft.yaml) of Tk-Instruct. It remains unclear whether Tk-Instruct was initialised from a regular checkpoint (*google/t5-v1_1-base*) or the one adapted explicitly for Language Modelling with continued training (*google/t5-base-lm-adapt*). Therefore, we decided to evaluate both. Run the following command to reproduce the Tk-Instruct experiments:
```
python -m nanoT5.main task=ft \
model.name={google/t5-v1_1-base,google/t5-base-lm-adapt} \
model.random_init={true,false} \
model.checkpoint_path={"","/path/to/pytorch_model.bin"}
```
Setting `model.random_init=false model.checkpoint_path=""` corresponds to downloading pre-trained weights from HuggingFace Hub.
Setting `model.random_init=false model.checkpoint_path="/path/to/pytorch_model.bin"` corresponds to using the weights [**pre-trained**](#pre-training) with nanoT5.
Setting `model.random_init=true model.checkpoint_path=""` corresponds to a random initialisation.
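For example, to fine-tune a checkpoint pre-trained with nanoT5 (the checkpoint path below is a placeholder):
```
python -m nanoT5.main task=ft \
    model.name=google/t5-v1_1-base \
    model.random_init=false \
    model.checkpoint_path="/path/to/pytorch_model.bin"
```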
#### Fine-tuning loss curves:
![ft_loss](assets/ft_loss.png)
#### Rouge-L on the held-out test-set:
![ft_rougeL](assets/ft_rougeL.png)
### Efficiency statistics:
<div align="center">
| | **Pre-training** | **Fine-tuning** |
| :---: | :----: | :---: |
| **One training step** | ~1.05s | ~0.175s |
| **Steps** | 65536 | 18830 |
| **Full training** | ~19h | ~1h |
</div>
For pre-training, we compile our model with PyTorch 2.0 using the `model.compile=true` flag.
## Conclusions:
We show that it is possible to successfully pre-train a "Large Language Model" (T5) under a limited budget (1xA100 GPU, ~20 hours) in PyTorch. We make our codebase, configs and training logs publicly available to enhance the accessibility of NLP research. We are keen to hear your suggestions to improve the codebase further.
## References:
- [T5 paper](https://arxiv.org/pdf/1910.10683.pdf)
- [T5 v1.1 paper](https://arxiv.org/pdf/2002.05202.pdf)
- [Super-Natural Instructions paper](https://arxiv.org/pdf/2204.07705.pdf)
- [HuggingFace Flax Script](https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py)
- [Karpathy's nanoGPT](https://github.com/karpathy/nanoGPT)
- [Instruct-GPT codebase (Super-Natural Instructions)](https://github.com/yizhongw/Tk-Instruct)
- [Blog about pre-training Dutch T5 in HuggingFace](https://huggingface.co/spaces/yhavinga/pre-training-dutch-t5-models)
## Issues:
If you have any questions, feel free to raise a Github issue or contact me directly at: piotr.nawrot@ed.ac.uk

30 assets/env_dump/lscpu.txt Normal file
@@ -0,0 +1,30 @@
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 1
Core(s) per socket: 64
Socket(s): 2
NUMA node(s): 8
Vendor ID: AuthenticAMD
CPU family: 25
Model: 1
Model name: AMD EPYC 7763 64-Core Processor
Stepping: 1
CPU MHz: 2445.206
BogoMIPS: 4890.41
Virtualization: AMD-V
L1d cache: 32K
L1i cache: 32K
L2 cache: 512K
L3 cache: 32768K
NUMA node0 CPU(s): 0-15
NUMA node1 CPU(s): 16-31
NUMA node2 CPU(s): 32-47
NUMA node3 CPU(s): 48-63
NUMA node4 CPU(s): 64-79
NUMA node5 CPU(s): 80-95
NUMA node6 CPU(s): 96-111
NUMA node7 CPU(s): 112-127
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca

20 assets/env_dump/nvidia_smi.txt Normal file
@@ -0,0 +1,20 @@
Tue Mar 14 23:23:03 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02 Driver Version: 510.85.02 CUDA Version: 11.6 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA A100-SXM... On | 00000000:01:00.0 Off | 0 |
| N/A 55C P0 282W / 500W | 31620MiB / 81920MiB | 89% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 963448 C python 31617MiB |
+-----------------------------------------------------------------------------+

154 assets/env_dump/pip_freeze.txt Normal file
@@ -0,0 +1,154 @@
absl-py==1.4.0
accelerate==0.17.1
aiohttp==3.8.4
aiosignal==1.3.1
antlr4-python3-runtime==4.9.3
anyio==3.6.2
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
async-timeout==4.0.2
attrs==22.2.0
backcall==0.2.0
beautifulsoup4==4.11.2
bleach==6.0.0
boto3==1.26.91
botocore==1.29.91
bravado==11.0.3
bravado-core==5.17.1
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==2.1.1
click==8.1.3
cmake==3.25.0
comm==0.1.2
datasets==2.10.1
debugpy==1.6.6
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.6
evaluate==0.4.0
executing==1.2.0
fancycompleter==0.9.1
fastjsonschema==2.16.3
filelock==3.9.0
fqdn==1.5.1
frozenlist==1.3.3
fsspec==2023.3.0
future==0.18.3
gitdb==4.0.10
GitPython==3.1.31
huggingface-hub==0.13.2
hydra-core==1.3.2
idna==3.4
importlib-metadata==6.0.0
importlib-resources==5.12.0
ipykernel==6.21.3
ipython==8.11.0
ipython-genutils==0.2.0
isoduration==20.11.0
jedi==0.18.2
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.2.0
jsonpointer==2.3
jsonref==1.1.0
jsonschema==4.17.3
jupyter-events==0.6.3
jupyter_client==8.0.3
jupyter_core==5.2.0
jupyter_server==2.4.0
jupyter_server_terminals==0.4.4
jupyterlab-pygments==0.2.2
lit==15.0.7
MarkupSafe==2.1.2
matplotlib-inline==0.1.6
mistune==2.0.5
monotonic==1.6
mpmath==1.2.1
msgpack==1.0.5
multidict==6.0.4
multiprocess==0.70.14
nbclassic==0.5.3
nbclient==0.7.2
nbconvert==7.2.10
nbformat==5.7.3
neptune==1.0.2
nest-asyncio==1.5.6
networkx==3.0rc1
nltk==3.8.1
notebook==6.5.3
notebook_shim==0.2.2
numpy==1.24.1
oauthlib==3.2.2
omegaconf==2.3.0
packaging==23.0
pandas==1.5.3
pandocfilters==1.5.0
parso==0.8.3
pdbpp==0.10.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.3.0
pkgutil_resolve_name==1.3.10
platformdirs==3.1.1
prometheus-client==0.16.0
prompt-toolkit==3.0.38
protobuf==3.20.3
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==11.0.0
pycparser==2.21
Pygments==2.14.0
PyJWT==2.6.0
pynvml==11.5.0
pyrepl==0.9.0
pyrsistent==0.19.3
python-dateutil==2.8.2
python-json-logger==2.0.7
pytorch-triton==2.1.0+2c32f43999
pytz==2022.7.1
PyYAML==6.0
pyzmq==25.0.1
regex==2022.10.31
requests==2.28.1
requests-oauthlib==1.3.1
responses==0.18.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rfc3987==1.3.8
rouge-score==0.1.2
s3transfer==0.6.0
Send2Trash==1.8.0
sentencepiece==0.1.97
simplejson==3.18.4
six==1.16.0
smmap==5.0.0
sniffio==1.3.0
soupsieve==2.4
stack-data==0.6.2
swagger-spec-validator==3.0.3
sympy==1.11.1
terminado==0.17.1
tinycss2==1.2.1
tokenizers==0.13.2
torch==2.1.0.dev20230315+cu117
torchaudio==2.0.0.dev20230313+cu117
torchvision==0.15.0.dev20230315+cu117
tornado==6.2
tqdm==4.65.0
traitlets==5.9.0
transformers==4.27.0
typing_extensions==4.4.0
uri-template==1.2.0
urllib3==1.26.13
wcwidth==0.2.6
webcolors==1.12
webencodings==0.5.1
websocket-client==1.5.1
wmctrl==0.4
xxhash==3.2.0
yarl==1.8.2
zipp==3.15.0

BIN assets/ft_loss.png Normal file (binary file not shown; 94 KiB)
BIN assets/ft_rougeL.png Normal file (binary file not shown; 45 KiB)
BIN assets/nanoT5.png Normal file (binary file not shown; 406 KiB)
BIN assets/pt_loss.png Normal file (binary file not shown; 105 KiB)

0 nanoT5/__init__.py Normal file

59 nanoT5/configs/default.yaml Normal file
@@ -0,0 +1,59 @@
defaults:
- _self_
- task: pt
# Experiment args
mode: 'pt'
device: gpu
eval_only: false
predict_only: false
seed: 2137
model:
name: 'google/t5-v1_1-base'
checkpoint_path: ''
dropout: 0.0
random_init: true
compile: false # Pytorch 2.0
data:
input_length: 512
mlm_probability: 0.15
mean_noise_span_length: 3.0
num_workers: 8
optim:
name: adamwscale
base_lr: 2e-2
batch_size: 144
total_steps: 65536
epochs: -1 # If it's > 0 it overwrites total_steps
warmup_steps: 10000
lr_scheduler: cosine
weight_decay: 0.0
grad_clip: 1.0
grad_acc: 2
final_cosine: 1e-5
eval:
every_steps: 100000 # Don't eval
steps: 500
checkpoint:
every_steps: 30000
logging:
neptune: false
neptune_creds:
project:
api_token:
tags:
every_steps: 100
grad_l2: true
weights_l2: true
hydra:
job:
chdir: True
run:
dir: ./logs/${now:%Y-%m-%d}/${now:%H-%M-%S}

31 nanoT5/configs/task/ft.yaml Normal file
@@ -0,0 +1,31 @@
# @package _global_
mode: 'ft'
data:
max_seq_len: 1024
max_target_len: 128
max_num_instances_per_task: 100
add_task_name: False
add_task_definition: True
num_pos_examples: 2
num_neg_examples: 0
add_explanation: False
tk_instruct: False
exec_file_path: ./nanoT5/utils/ni_dataset.py
data_dir: ./data/splits/default
task_dir: ./data/tasks
optim:
name: adamw
base_lr: 5e-5
batch_size: 8
epochs: 2
warmup_steps: 0
lr_scheduler: constant
weight_decay: 0.0
grad_clip: 0.0
grad_acc: 1
eval:
steps: 200

1 nanoT5/configs/task/pt.yaml Normal file
@@ -0,0 +1 @@
# @package _global_

69 nanoT5/main.py Normal file
@@ -0,0 +1,69 @@
from accelerate import Accelerator
from omegaconf import open_dict
import hydra
import torch
import time
from .utils import (
setup_basics,
train,
predict,
eval,
get_lr_scheduler,
get_optimizer,
get_tokenizer,
get_model,
get_dataloaders,
get_config,
)
@hydra.main(config_path="configs", config_name="default", version_base='1.1')
def main(args):
accelerator = Accelerator(cpu=args.device == "cpu")
logger = setup_basics(accelerator, args)
config = get_config(args)
model = get_model(args, config)
tokenizer = get_tokenizer(args)
optimizer = get_optimizer(model, args)
lr_scheduler = get_lr_scheduler(optimizer, args, logger)
train_dataloader, test_dataloader = get_dataloaders(tokenizer, config, args)
logger.log_args(args)
(
model,
optimizer,
lr_scheduler,
train_dataloader,
test_dataloader,
) = accelerator.prepare(
model, optimizer, lr_scheduler, train_dataloader, test_dataloader
)
if args.model.compile:
model = torch.compile(model)
with open_dict(args):
args.current_train_step = 1
args.current_epoch = 1
args.last_log = time.time()
if args.eval_only:
model.eval()
with torch.no_grad():
eval(model, test_dataloader, logger, args, tokenizer)
elif args.predict_only:
model.eval()
with torch.no_grad():
predict(model, test_dataloader, logger,
args, tokenizer)
else:
train(model, train_dataloader, test_dataloader, accelerator,
lr_scheduler, optimizer, logger, args, tokenizer)
logger.finish()
if __name__ == "__main__":
main()

3 nanoT5/utils/__init__.py Normal file
@@ -0,0 +1,3 @@
from .gen_utils import *
from .model_utils import *
from .train_utils import *

609 nanoT5/utils/copied_utils.py Normal file
@@ -0,0 +1,609 @@
from typing import Dict, List
import numpy as np
from transformers import BatchEncoding
from dataclasses import dataclass
from transformers import AutoTokenizer
import torch
import math
from torch.optim import Optimizer
from typing import Iterable, Tuple
from torch import nn
import random
import string
@dataclass
class DataCollatorForT5MLM:
"""
[Copied from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py]
Data collator used for T5 span-masked language modeling.
It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length.
For more information on how T5 span-masked language modeling works, one can take a look
at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
noise_density (:obj:`float`):
The probability with which to (randomly) mask tokens in the input.
mean_noise_span_length (:obj:`float`):
The average span length of the masked tokens.
input_length (:obj:`int`):
The expected input length after masking.
target_length (:obj:`int`):
The expected target length after masking.
pad_token_id: (:obj:`int`):
The pad token id of the model
decoder_start_token_id: (:obj:`int):
The decoder start token id of the model
"""
tokenizer: AutoTokenizer
noise_density: float
mean_noise_span_length: float
input_length: int
target_length: int
pad_token_id: int
def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
# convert list to dict and tensorize input
batch = BatchEncoding(
{
k: np.array([examples[i][k] for i in range(len(examples))])
for k, v in examples[0].items()
}
)
input_ids = batch["input_ids"]
batch_size, expandend_input_length = input_ids.shape
mask_indices = np.asarray(
[
self.random_spans_noise_mask(expandend_input_length)
for i in range(batch_size)
]
)
labels_mask = ~mask_indices
input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
if batch["input_ids"].shape[-1] != self.input_length:
raise ValueError(
f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but"
f" should be {self.input_length}."
)
if batch["labels"].shape[-1] != self.target_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be"
f" {self.target_length}."
)
batch = {k: torch.from_numpy(v) for k, v in batch.items()}
return batch
def create_sentinel_ids(self, mask_indices):
"""
Sentinel ids creation given the indices that should be masked.
The start indices of each mask are replaced by the sentinel ids in increasing
order. Consecutive mask indices to be deleted are replaced with `-1`.
"""
start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
start_indices[:, 0] = mask_indices[:, 0]
sentinel_ids = np.where(
start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
)
sentinel_ids = np.where(
sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0
)
sentinel_ids -= mask_indices - start_indices
return sentinel_ids
def filter_input_ids(self, input_ids, sentinel_ids):
"""
Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
"""
batch_size = input_ids.shape[0]
input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
# input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
# masked tokens coming after sentinel tokens and should be removed
input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
input_ids = np.concatenate(
[
input_ids,
np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32),
],
axis=-1,
)
return input_ids
def random_spans_noise_mask(self, length):
"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
Noise mask consisting of random spans of noise tokens.
The number of noise tokens and the number of noise spans and non-noise spans
are determined deterministically as follows:
num_noise_tokens = round(length * noise_density)
num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
Spans alternate between non-noise and noise, beginning with non-noise.
Subject to the above restrictions, all masks are equally likely.
Args:
length: an int32 scalar (length of the incoming token sequence)
noise_density: a float - approximate density of output mask
mean_noise_span_length: a number
Returns:
a boolean tensor with shape [length]
"""
orig_length = length
num_noise_tokens = int(np.round(length * self.noise_density))
# avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))
# avoid degeneracy by ensuring positive number of noise spans
num_noise_spans = max(num_noise_spans, 1)
num_nonnoise_tokens = length - num_noise_tokens
# pick the lengths of the noise spans and the non-noise spans
def _random_segmentation(num_items, num_segments):
"""Partition a sequence of items randomly into non-empty segments.
Args:
num_items: an integer scalar > 0
num_segments: an integer scalar in [1, num_items]
Returns:
a Tensor with shape [num_segments] containing positive integers that add
up to num_items
"""
mask_indices = np.arange(num_items - 1) < (num_segments - 1)
np.random.shuffle(mask_indices)
first_in_segment = np.pad(mask_indices, [[1, 0]])
segment_id = np.cumsum(first_in_segment)
# count length of sub segments assuming that list is sorted
_, segment_length = np.unique(segment_id, return_counts=True)
return segment_length
noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
nonnoise_span_lengths = _random_segmentation(
num_nonnoise_tokens, num_noise_spans
)
interleaved_span_lengths = np.reshape(
np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
[num_noise_spans * 2],
)
span_starts = np.cumsum(interleaved_span_lengths)[:-1]
span_start_indicator = np.zeros((length,), dtype=np.int8)
span_start_indicator[span_starts] = True
span_num = np.cumsum(span_start_indicator)
is_noise = np.equal(span_num % 2, 1)
return is_noise[:orig_length]
def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
[Copied from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py]
Training parameters to avoid padding with random_spans_noise_mask.
When training a model with random_spans_noise_mask, we would like to set the other
training hyperparmeters in a way that avoids padding.
This function helps us compute these hyperparameters.
We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
This function tells us the required number of tokens in the raw example (for split_tokens())
as well as the length of the encoded targets. Note that this function assumes
the inputs and targets will have EOS appended and includes that in the reported length.
Args:
inputs_length: an integer - desired length of the tokenized inputs sequence
noise_density: a float
mean_noise_span_length: a float
Returns:
tokens_length: length of original text in tokens
targets_length: an integer - length in tokens of encoded targets sequence
"""
def _tokens_length_to_inputs_length_targets_length(tokens_length):
num_noise_tokens = int(round(tokens_length * noise_density))
num_nonnoise_tokens = tokens_length - num_noise_tokens
num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
# inputs contain all nonnoise tokens, sentinels for all noise spans
# and one EOS token.
_input_length = num_nonnoise_tokens + num_noise_spans + 1
_output_length = num_noise_tokens + num_noise_spans + 1
return _input_length, _output_length
tokens_length = inputs_length
while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
tokens_length += 1
inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
# minor hack to get the targets length to be equal to inputs length
# which is more likely to have been set to a nice round number.
if noise_density == 0.5 and targets_length > inputs_length:
tokens_length -= 1
targets_length -= 1
return tokens_length, targets_length
class AdamWScale(Optimizer):
"""
This AdamW implementation is copied from Huggingface.
We modified it to scale the step size by the root mean square (RMS) of each weight tensor, as in Adafactor.
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
Regularization](https://arxiv.org/abs/1711.05101).
Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 1e-3):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)):
Adam's betas parameters (b1, b2).
eps (`float`, *optional*, defaults to 1e-6):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0):
Decoupled weight decay to apply.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
no_deprecation_warning (`bool`, *optional*, defaults to `False`):
A flag used to disable the deprecation warning (set to `True` to disable the warning).
"""
def __init__(
self,
params: Iterable[nn.parameter.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.0,
correct_bias: bool = True,
):
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
super().__init__(params, defaults)
@staticmethod
def _rms(tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)
def step(self, closure=None):
"""
Performs a single optimization step.
Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
state = self.state[p]
beta1, beta2 = group["betas"]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
state["step"] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
step_size = group["lr"]
if group["correct_bias"]: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
# /Adapt Step from Adagrad
step_size = step_size * max(1e-3, self._rms(p.data))
# /Adapt Step from Adagrad
p.data.addcdiv_(exp_avg, denom, value=-step_size)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0:
p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"]))
return loss
def tokenize_function(examples, tokenizer, in_length):
tokenizer_out = tokenizer(
text=examples["text"],
return_attention_mask=False,
)
input_ids = tokenizer_out["input_ids"]
concatenated_ids = np.concatenate(input_ids)
total_length = concatenated_ids.shape[0]
total_length = (total_length // in_length) * in_length
concatenated_ids = concatenated_ids[:total_length].reshape(-1, in_length)
result = {"input_ids": concatenated_ids}
return result
from transformers.data.data_collator import *
@dataclass
class DataCollatorForNI:
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_source_length: Optional[int] = None
max_target_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
return_tensors: str = "pt"
add_task_name: bool = False
add_task_definition: bool = True
num_pos_examples: int = 0
num_neg_examples: int = 0
add_explanation: bool = False
tk_instruct: bool = False
text_only: bool = False
def __call__(self, batch, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
sources = []
for instance in batch:
if self.tk_instruct:
all_valid_encodings = [
# instruction only
{
"add_task_name": False,
"add_task_definition": True,
"num_pos_examples": 0,
"num_neg_examples": 0,
"add_explanation": False,
},
# example only
{
"add_task_name": False,
"add_task_definition": False,
"num_pos_examples": 2,
"num_neg_examples": 0,
"add_explanation": False,
},
# instruction + pos examples
{
"add_task_name": False,
"add_task_definition": True,
"num_pos_examples": 2,
"num_neg_examples": 0,
"add_explanation": False,
},
# instruction + pos examples + neg examples
{
"add_task_name": False,
"add_task_definition": True,
"num_pos_examples": 2,
"num_neg_examples": 2,
"add_explanation": False,
},
# instruction + pos (w. explanation)
{
"add_task_name": False,
"add_task_definition": True,
"num_pos_examples": 2,
"num_neg_examples": 0,
"add_explanation": True,
},
]
encoding_schema = random.choice(all_valid_encodings)
add_task_name = encoding_schema["add_task_name"]
add_task_definition = encoding_schema["add_task_definition"]
num_pos_examples = encoding_schema["num_pos_examples"]
num_neg_examples = encoding_schema["num_neg_examples"]
add_explanation = encoding_schema["add_explanation"]
else:
add_task_name = self.add_task_name
add_task_definition = self.add_task_definition
num_pos_examples = self.num_pos_examples
num_neg_examples = self.num_neg_examples
add_explanation = self.add_explanation
task_input = ""
# add the input first.
task_input += "Now complete the following example -\n"
task_input += f"Input: {instance['Instance']['input'].strip()}"
if not task_input[-1] in string.punctuation:
task_input += "."
task_input += "\n"
task_input += "Output: "
task_name = ""
if add_task_name:
task_name += instance["Task"] + ". "
definition = ""
if add_task_definition:
if isinstance(instance["Definition"], list):
definition = (
"Definition: " + instance["Definition"][0].strip()
)
else:
definition = "Definition: " + instance["Definition"].strip()
if not definition[-1] in string.punctuation:
definition += "."
definition += "\n\n"
# try to add positive examples.
pos_examples = []
for idx, pos_example in enumerate(
instance["Positive Examples"][:num_pos_examples]
):
pos_example_str = f" Positive Example {idx+1} -\n"
pos_example_str += f"Input: {pos_example['input'].strip()}"
if not pos_example_str[-1] in string.punctuation:
pos_example_str += "."
pos_example_str += "\n"
pos_example_str += f" Output: {pos_example['output'].strip()}"
if not pos_example_str[-1] in string.punctuation:
pos_example_str += "."
pos_example_str += "\n"
if add_explanation and "explanation" in pos_example:
pos_example_str += (
f" Explanation: {pos_example['explanation'].strip()}"
)
if not pos_example_str[-1] in string.punctuation:
pos_example_str += "."
pos_example_str += "\n"
pos_example_str += "\n"
if (
len(
self.tokenizer(
definition
+ " ".join(pos_examples)
+ pos_example_str
+ task_input
)["input_ids"]
)
<= self.max_source_length
):
pos_examples.append(pos_example_str)
else:
break
# try to add negative examples.
neg_examples = []
for idx, neg_example in enumerate(
instance["Negative Examples"][:num_neg_examples]
):
neg_example_str = f" Negative Example {idx+1} -\n"
neg_example_str += f"Input: {neg_example['input'].strip()}"
if not neg_example_str[-1] in string.punctuation:
neg_example_str += "."
neg_example_str += "\n"
neg_example_str += f" Output: {neg_example['output'].strip()}"
if not neg_example_str[-1] in string.punctuation:
neg_example_str += "."
neg_example_str += "\n"
if add_explanation and "explanation" in neg_example:
neg_example_str += (
f" Explanation: {neg_example['explanation'].strip()}"
)
if not neg_example_str[-1] in string.punctuation:
neg_example_str += "."
neg_example_str += "\n"
neg_example_str += "\n"
if (
len(
self.tokenizer(
definition
+ " ".join(pos_examples)
+ " ".join(neg_examples)
+ neg_example_str
+ task_input
)["input_ids"]
)
<= self.max_source_length
):
neg_examples.append(neg_example_str)
else:
break
source = (
task_name
+ definition
+ "".join(pos_examples)
+ "".join(neg_examples)
+ task_input
)
tokenized_source = self.tokenizer(source)["input_ids"]
if len(tokenized_source) <= self.max_source_length:
sources.append(source)
else:
sources.append(
self.tokenizer.decode(
tokenized_source[: self.max_source_length],
skip_special_tokens=True,
)
)
if self.text_only:
model_inputs = {"inputs": sources}
else:
model_inputs = self.tokenizer(
sources,
max_length=self.max_source_length,
padding=self.padding,
return_tensors=self.return_tensors,
truncation=True,
pad_to_multiple_of=self.pad_to_multiple_of,
)
if "output" in batch[0]["Instance"] and batch[0]["Instance"]["output"]:
# Randomly select one reference if multiple are provided.
labels = [random.choice(ex["Instance"]["output"]) for ex in batch]
if self.text_only:
model_inputs["labels"] = labels
else:
labels = self.tokenizer(
labels,
max_length=self.max_target_length,
padding=self.padding,
return_tensors=self.return_tensors,
truncation=True,
pad_to_multiple_of=self.pad_to_multiple_of,
)
label_mask = labels["attention_mask"].bool()
model_inputs["labels"] = labels["input_ids"].masked_fill(
~label_mask, self.label_pad_token_id
)
else:
model_inputs["labels"] = None
return model_inputs

61 nanoT5/utils/gen_utils.py Normal file
@@ -0,0 +1,61 @@
import torch
import os
from accelerate.utils import set_seed
from omegaconf import open_dict
from .logging_utils import Logger
from hydra.utils import to_absolute_path
def check_args_and_env(args):
assert args.optim.batch_size % args.optim.grad_acc == 0
# Train log must happen before eval log
assert args.eval.every_steps % args.logging.every_steps == 0
if args.device == 'gpu':
assert torch.cuda.is_available(), 'We use GPU to train/eval the model'
assert not (args.eval_only and args.predict_only)
if args.predict_only:
assert args.mode == 'ft'
def opti_flags(args):
# These flags reduce the training step time by 2.4x
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def update_args_with_env_info(args):
with open_dict(args):
slurm_id = os.getenv('SLURM_JOB_ID')
if slurm_id is not None:
args.slurm_id = slurm_id
else:
args.slurm_id = 'none'
args.working_dir = os.getcwd()
def update_paths(args):
if args.mode == 'ft':
args.data.exec_file_path = to_absolute_path(args.data.exec_file_path)
args.data.data_dir = to_absolute_path(args.data.data_dir)
args.data.task_dir = to_absolute_path(args.data.task_dir)
def setup_basics(accelerator, args):
check_args_and_env(args)
update_args_with_env_info(args)
update_paths(args)
opti_flags(args)
if args.seed is not None:
set_seed(args.seed)
logger = Logger(args=args, accelerator=accelerator)
return logger

95 nanoT5/utils/logging_utils.py Normal file
@@ -0,0 +1,95 @@
from collections import defaultdict
from accelerate.logging import get_logger
from omegaconf import OmegaConf, open_dict
import logging
import datasets
import transformers
import neptune
import os
class Averager:
def __init__(self, weight: float = 1):
self.weight = weight
self.reset()
def reset(self):
self.total = defaultdict(float)
self.counter = defaultdict(float)
def update(self, stats):
for key, value in stats.items():
self.total[key] = self.total[key] * self.weight + value * self.weight
self.counter[key] = self.counter[key] * self.weight + self.weight
def average(self):
averaged_stats = {
key: tot / self.counter[key] for key, tot in self.total.items()
}
self.reset()
return averaged_stats
class Logger:
def __init__(self, args, accelerator):
self.logger = get_logger('Main')
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
self.logger.info(accelerator.state, main_process_only=False)
self.logger.info(f'Working directory is {os.getcwd()}')
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
self.setup_neptune(args)
def setup_neptune(self, args):
if args.logging.neptune:
neptune_logger = neptune.init_run(
project=args.logging.neptune_creds.project,
api_token=args.logging.neptune_creds.api_token,
tags=[str(item) for item in args.logging.neptune_creds.tags.split(",")],
)
else:
neptune_logger = None
self.neptune_logger = neptune_logger
with open_dict(args):
if neptune_logger is not None:
args.neptune_id = neptune_logger["sys/id"].fetch()
def log_args(self, args):
if self.neptune_logger is not None:
logging_args = OmegaConf.to_container(args, resolve=True)
self.neptune_logger['args'] = logging_args
def log_stats(self, stats, step, args, prefix=''):
if self.neptune_logger is not None:
for k, v in stats.items():
self.neptune_logger[f'{prefix}{k}'].log(v, step=step)
msg_start = f'[{prefix[:-1]}] Step {step} out of {args.optim.total_steps}' + ' | '
dict_msg = ' | '.join([f'{k.capitalize()} --> {v:.3f}' for k, v in stats.items()]) + ' | '
msg = msg_start + dict_msg
self.log_message(msg)
def log_message(self, msg):
self.logger.info(msg)
def finish(self):
if self.neptune_logger is not None:
self.neptune_logger.stop()

321 nanoT5/utils/model_utils.py Normal file
@@ -0,0 +1,321 @@
import torch
import datasets
from torch.utils.data import DataLoader
from omegaconf import open_dict
from datasets.iterable_dataset import IterableDataset
from transformers import (
AutoTokenizer,
T5ForConditionalGeneration,
AutoConfig,
)
from .copied_utils import (
compute_input_and_target_lengths,
DataCollatorForT5MLM,
tokenize_function,
DataCollatorForNI,
)
def get_model(args, config):
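    # Three ways to obtain the model: load weights from a local checkpoint, start from a random
    # initialization, or load the pretrained Hugging Face weights for args.model.name.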
if args.model.checkpoint_path:
model = T5ForConditionalGeneration(
config,
)
model.load_state_dict(torch.load(args.model.checkpoint_path))
elif args.model.random_init:
model = T5ForConditionalGeneration(
config,
)
else:
model = T5ForConditionalGeneration.from_pretrained(
args.model.name,
config=config,
)
return model
def get_config(args):
config = AutoConfig.from_pretrained(
args.model.name,
)
config.dropout_rate = args.model.dropout
return config
def get_tokenizer(args):
tokenizer = AutoTokenizer.from_pretrained(
args.model.name,
use_fast=True
)
tokenizer.model_max_length = int(1e9)
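    # A huge model_max_length keeps the tokenizer from warning about (or truncating) long
    # sequences; lengths are controlled by the data pipeline instead.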
return tokenizer
def load_dataset_splits(args):
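    # 'pt' (pre-training) streams the English C4 corpus; 'ft' (fine-tuning) loads a local dataset
    # builder script via args.data.exec_file_path (presumably the Natural-Instructions loader in this repo).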
if args.mode == 'pt':
dataset = datasets.load_dataset(
'c4',
'en',
streaming=True,
)
dataset = dataset.remove_columns(
['timestamp', 'url']
)
dataset_splits = {
'train': dataset['train'],
'test': dataset['validation'],
}
assert (
dataset['train'].n_shards == 1024
), "We want to have many shards for efficient processing with num_workes in PyTorch dataloader"
elif args.mode == 'ft':
dataset_splits = datasets.load_dataset(
args.data.exec_file_path,
data_dir=args.data.data_dir,
task_dir=args.data.task_dir,
max_num_instances_per_task=args.data.max_num_instances_per_task,
max_num_instances_per_eval_task=args.data.max_num_instances_per_task
)
else:
raise NotImplementedError
return dataset_splits
def process_dataset(dataset_splits, args, tokenizer):
if args.mode == 'pt':
final_datasets = {}
for split, dataset_split in dataset_splits.items():
            # We increase the input_length because, instead of masking individual tokens, T5 replaces
            # each masked span with a single sentinel token. To avoid padding after masking, the
            # sequences therefore need to be longer before masking.
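            # For example, assuming the standard T5 defaults (input_length=512, mlm_probability=0.15,
            # mean_noise_span_length=3.0), the call below returns before_mask_input_length=568 and
            # target_length=114.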
before_mask_input_length, target_length = compute_input_and_target_lengths(
inputs_length=args.data.input_length,
noise_density=args.data.mlm_probability,
mean_noise_span_length=args.data.mean_noise_span_length,
)
with open_dict(args):
args.data.before_mask_input_length = before_mask_input_length
args.data.target_length = target_length
dataset_split = dataset_split.map(
tokenize_function,
batched=True,
fn_kwargs={
'tokenizer': tokenizer,
'in_length': before_mask_input_length,
},
remove_columns=['text'],
)
dataset_split = dataset_split.shuffle(buffer_size=10_000, seed=args.seed)
final_datasets[split] = dataset_split
elif args.mode == 'ft':
final_datasets = dataset_splits
else:
raise NotImplementedError
return final_datasets
def get_data_collator(tokenizer, config, args):
if args.mode == 'pt':
data_collator = DataCollatorForT5MLM(
tokenizer=tokenizer,
noise_density=args.data.mlm_probability,
mean_noise_span_length=args.data.mean_noise_span_length,
input_length=args.data.input_length,
target_length=args.data.target_length,
pad_token_id=config.pad_token_id,
)
elif args.mode == 'ft':
data_collator = DataCollatorForNI(
tokenizer,
padding="longest",
max_source_length=args.data.max_seq_len,
max_target_length=args.data.max_target_len,
label_pad_token_id=-100,
pad_to_multiple_of=8,
add_task_name=args.data.add_task_name,
add_task_definition=args.data.add_task_definition,
num_pos_examples=args.data.num_pos_examples,
num_neg_examples=args.data.num_neg_examples,
add_explanation=args.data.add_explanation,
tk_instruct=args.data.tk_instruct
)
else:
raise NotImplementedError
return data_collator
def get_dataloaders(tokenizer, config, args):
dataset_splits = load_dataset_splits(args)
dataset = process_dataset(dataset_splits=dataset_splits, args=args, tokenizer=tokenizer)
data_collator = get_data_collator(tokenizer=tokenizer, config=config,
args=args)
is_iterable = isinstance(dataset['train'], IterableDataset)
dataloaders = {}
for split in ['train', 'test']:
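        # Per-device micro-batch size; gradient accumulation restores the effective batch size.
        # Evaluation can afford a 2x larger batch because it runs without gradients.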
batch_size = args.optim.batch_size // args.optim.grad_acc
if split in ['test']:
batch_size *= 2
shuffle = (split == 'train') and not is_iterable
if args.mode == 'ft' and split == 'train':
assert shuffle is True
else:
assert shuffle is False
dataloaders[split] = DataLoader(
dataset[split],
shuffle=shuffle,
collate_fn=data_collator,
batch_size=batch_size,
num_workers=args.data.num_workers,
pin_memory=True,
drop_last=False,
)
# Add & Check args about data loaders
with open_dict(args):
if not is_iterable:
args.data.train_batches = len(dataloaders['train'])
args.data.test_batches = len(dataloaders['test'])
if args.optim.epochs > 0:
assert not is_iterable
args.optim.total_steps = len(dataloaders['train']) * args.optim.epochs
        # The eval batch size is doubled, so the number of eval steps is halved
args.eval.corrected_steps = args.eval.steps / 2
return dataloaders['train'], dataloaders['test']
def get_optimizer(model, args):
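    # Parameters whose names contain any of these substrings (biases and the various LayerNorm
    # spellings) are excluded from weight decay.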
no_decay = ["bias", "LayerNorm", "layernorm", "layer_norm", "ln"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.optim.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
if args.optim.name == 'adamw':
from transformers import AdamW
optimizer = AdamW(
optimizer_grouped_parameters,
lr=args.optim.base_lr,
)
elif args.optim.name == 'adamwscale':
from .copied_utils import AdamWScale
optimizer = AdamWScale(
optimizer_grouped_parameters,
lr=args.optim.base_lr,
)
elif args.optim.name == 'adafactor':
from transformers import Adafactor
optimizer = Adafactor(
optimizer_grouped_parameters,
lr=args.optim.base_lr,
relative_step=False,
)
else:
raise NotImplementedError
return optimizer
def get_lr_scheduler(optimizer, args, logger):
if args.optim.lr_scheduler == 'cosine':
from torch.optim.lr_scheduler import (
SequentialLR,
LinearLR,
CosineAnnealingLR,
)
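        # Linear warmup from 0.5x to 1.0x of base_lr over warmup_steps, followed by cosine decay
        # down to optim.final_cosine.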
scheduler1 = LinearLR(
optimizer,
start_factor=0.5,
end_factor=1,
total_iters=args.optim.warmup_steps,
last_epoch=-1,
)
scheduler2 = CosineAnnealingLR(
optimizer,
T_max=args.optim.total_steps - args.optim.warmup_steps,
eta_min=args.optim.final_cosine,
)
lr_scheduler = SequentialLR(
optimizer,
schedulers=[scheduler1, scheduler2],
milestones=[args.optim.warmup_steps]
)
elif args.optim.lr_scheduler == 'legacy':
import math
from torch.optim.lr_scheduler import (
SequentialLR,
LinearLR,
LambdaLR,
)
msg = "You are using T5 legacy LR Schedule, it's independent from the optim.base_lr"
logger.log_message(msg)
num_steps_optimizer1 = math.ceil(args.optim.total_steps * 0.9)
iters_left_for_optimizer2 = args.optim.total_steps - num_steps_optimizer1
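        # Inverse square-root schedule from the original T5 setup: the LambdaLR factor is
        # min(1e-2, 1/sqrt(step)) / base_lr, so the effective LR equals min(1e-2, 1/sqrt(step))
        # regardless of optim.base_lr; the final 10% of steps then decay linearly to zero.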
scheduler1 = LambdaLR(
optimizer,
lambda step: min(
1e-2, 1.0 / math.sqrt(step)
) / args.optim.base_lr if step else 1e-2 / args.optim.base_lr
)
scheduler2 = LinearLR(
optimizer,
start_factor=(
min(1e-2, 1.0 / math.sqrt(num_steps_optimizer1)) / args.optim.base_lr
),
end_factor=0,
total_iters=iters_left_for_optimizer2,
last_epoch=-1,
)
lr_scheduler = SequentialLR(
optimizer,
schedulers=[scheduler1, scheduler2],
milestones=[num_steps_optimizer1]
)
elif args.optim.lr_scheduler == 'constant':
from transformers import get_scheduler
lr_scheduler = get_scheduler(
name=args.optim.lr_scheduler,
optimizer=optimizer,
)
else:
raise NotImplementedError
return lr_scheduler

173
nanoT5/utils/ni_dataset.py Normal file
View File

@ -0,0 +1,173 @@
# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Natural Instruction V2 Dataset."""
import json
import os
import random
import datasets
logger = datasets.logging.get_logger(__name__)
_CITATION = """
@article{wang2022benchmarking,
title={Benchmarking Generalization via In-Context Instructions on 1,600+ Language Tasks},
author={Wang, Yizhong and Mishra, Swaroop and Alipoormolabashi, Pegah and Kordi, Yeganeh and others},
journal={arXiv preprint arXiv:2204.07705},
year={2022}
}
"""
_DESCRIPTION = """
Natural-Instructions v2 is a benchmark of 1,600+ diverse language tasks and their expert-written instructions.
It covers 70+ distinct task types, such as tagging, in-filling, and rewriting.
These tasks are collected with contributions of NLP practitioners in the community and
through an iterative peer review process to ensure their quality.
"""
_URL = "https://instructions.apps.allenai.org/"
class NIConfig(datasets.BuilderConfig):
def __init__(self, *args, task_dir=None, max_num_instances_per_task=None, max_num_instances_per_eval_task=None, **kwargs):
super().__init__(*args, **kwargs)
self.task_dir: str = task_dir
self.max_num_instances_per_task: int = max_num_instances_per_task
self.max_num_instances_per_eval_task: int = max_num_instances_per_eval_task
class NaturalInstructions(datasets.GeneratorBasedBuilder):
"""NaturalInstructions Dataset."""
VERSION = datasets.Version("2.0.0")
BUILDER_CONFIG_CLASS = NIConfig
BUILDER_CONFIGS = [
NIConfig(name="default", description="Default config for NaturalInstructions")
]
DEFAULT_CONFIG_NAME = "default"
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"Task": datasets.Value("string"),
"Contributors": datasets.Value("string"),
"Source": [datasets.Value("string")],
"URL": [datasets.Value("string")],
"Categories": [datasets.Value("string")],
"Reasoning": [datasets.Value("string")],
"Definition": [datasets.Value("string")],
"Positive Examples": [{
"input": datasets.Value("string"),
"output": datasets.Value("string"),
"explanation": datasets.Value("string")
}],
"Negative Examples": [{
"input": datasets.Value("string"),
"output": datasets.Value("string"),
"explanation": datasets.Value("string")
}],
"Input_language": [datasets.Value("string")],
"Output_language": [datasets.Value("string")],
"Instruction_language": [datasets.Value("string")],
"Domains": [datasets.Value("string")],
# "Instances": [{
# "input": datasets.Value("string"),
# "output": [datasets.Value("string")]
# }],
"Instance": {
"id": datasets.Value("string"),
"input": datasets.Value("string"),
"output": [datasets.Value("string")]
},
"Instance License": [datasets.Value("string")]
}
),
supervised_keys=None,
homepage="https://github.com/allenai/natural-instructions",
citation=_CITATION,
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
if self.config.data_dir is None or self.config.task_dir is None:
dl_path = dl_manager.download_and_extract(_URL)
self.config.data_dir = self.config.data_dir or os.path.join(dl_path, "splits")
self.config.task_dir = self.config.task_dir or os.path.join(dl_path, "tasks")
split_dir = self.config.data_dir
task_dir = self.config.task_dir
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"path": os.path.join(split_dir, "train_tasks.txt"),
"task_dir": task_dir,
"max_num_instances_per_task": self.config.max_num_instances_per_task,
"subset": "train"
}),
# datasets.SplitGenerator(
# name=datasets.Split.VALIDATION,
# gen_kwargs={
# "path": os.path.join(split_dir, "dev_tasks.txt"),
# "task_dir": task_dir,
# "max_num_instances_per_task": self.config.max_num_instances_per_eval_task,
# "subset": "dev"
# }),
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"path": os.path.join(split_dir, "test_tasks.txt"),
"task_dir": task_dir,
"max_num_instances_per_task": self.config.max_num_instances_per_eval_task,
"subset": "test"
}),
]
def _generate_examples(self, path=None, task_dir=None, max_num_instances_per_task=None, subset=None):
"""Yields examples."""
logger.info(f"Generating tasks from = {path}")
with open(path, encoding="utf-8") as split_f:
for line in split_f:
task_name = line.strip()
task_path = os.path.join(task_dir, task_name + ".json")
with open(task_path, encoding="utf-8") as task_f:
s = task_f.read()
task_data = json.loads(s)
task_data["Task"] = task_name
if "Instruction Source" in task_data:
task_data.pop("Instruction Source")
all_instances = task_data.pop("Instances")
if subset == "test":
                        # For test tasks, the first 100 instances are label-balanced and were selected
                        # for efficient evaluation; they are placed first for reproducibility, so we use
                        # exactly those here.
instances = all_instances[:100]
else:
instances = all_instances
if max_num_instances_per_task is not None and max_num_instances_per_task >= 0:
random.shuffle(instances)
instances = instances[:max_num_instances_per_task]
for idx, instance in enumerate(instances):
example = task_data.copy()
example["id"] = instance["id"]
example["Instance"] = instance
yield f"{task_name}_{idx}", example

211
nanoT5/utils/train_utils.py Normal file
View File

@ -0,0 +1,211 @@
import torch
import time
import evaluate
from .logging_utils import Averager
from datasets.iterable_dataset import IterableDataset
def maybe_save_checkpoint(accelerator, args):
if (
args.current_train_step > args.optim.total_steps
or args.current_train_step % args.checkpoint.every_steps == 0
):
output_dir = f'checkpoint-{args.mode}-{args.current_train_step}'
accelerator.save_state(output_dir=output_dir)
def maybe_eval_predict(model, dataloader, logger, args, tokenizer):
if (
args.current_train_step > args.optim.total_steps
or args.current_train_step % args.eval.every_steps == 0
):
model.eval()
with torch.no_grad():
eval(model, dataloader, logger, args, tokenizer)
if args.mode == 'ft':
predict(
model, dataloader, logger, args, tokenizer
)
args.last_log = time.time()
model.train()
def maybe_logging(averager, args, model, optimizer, logger):
if args.current_train_step % args.logging.every_steps == 0:
stats = extra_stats(args, model, optimizer)
seconds_per_step = (time.time() - args.last_log) / args.logging.every_steps
stats['seconds_per_step'] = seconds_per_step
averager.update(stats)
averaged_stats = averager.average()
logger.log_stats(
stats=averaged_stats,
step=args.current_train_step,
args=args,
prefix='train/'
)
args.last_log = time.time()
def maybe_grad_clip_and_grad_calc(accelerator, model, args):
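    # Optionally compute the global gradient L2 norm for logging, then clip gradients through
    # accelerate so clipping works correctly with mixed precision and distributed training.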
if args.logging.grad_l2:
grad_l2 = (
sum(p.grad.detach().data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
)
else:
grad_l2 = None
if args.optim.grad_clip > 0:
accelerator.clip_grad_norm_(
parameters=model.parameters(),
max_norm=args.optim.grad_clip,
norm_type=2,
)
if grad_l2 is not None:
return {'grad_l2': grad_l2}
else:
return {}
def extra_stats(args, model, optimizer):
stats = {}
if args.logging.weights_l2:
weights_l2 = sum(p.detach().norm(2).item() ** 2 for p in model.parameters()) ** 0.5
stats['weights_l2'] = weights_l2
cur_lr = optimizer.param_groups[0]['lr']
stats['lr'] = cur_lr
return stats
def forward(model, batch, calc_acc=False):
outputs = model(**batch)
loss = outputs.loss
stats = {}
stats['loss'] = loss.detach().float().item()
if calc_acc:
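        # Token-level accuracy over all label positions; any -100 padding labels never match the
        # argmax and therefore count as incorrect.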
correct = (outputs.logits.argmax(-1) == batch["labels"]).sum().item()
accuracy = correct / batch["labels"].numel()
stats['accuracy'] = accuracy
return loss, stats
def eval(model, dataloader, logger, args, tokenizer):
args.last_log = time.time()
averager = Averager()
for batch_id, batch in enumerate(dataloader, start=1):
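        # Stop after roughly args.eval.steps effective batches: the eval dataloader uses 2x larger
        # micro-batches, hence corrected_steps = eval.steps / 2 (set in get_dataloaders).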
if batch_id == args.eval.corrected_steps * args.optim.grad_acc:
break
_, stats = forward(model, batch, calc_acc=True)
averager.update(stats)
averager.update({'time': time.time() - args.last_log})
averaged_stats = averager.average()
logger.log_stats(
stats=averaged_stats,
step=args.current_train_step,
args=args,
prefix='eval/'
)
def predict(model, dataloader, logger, args, tokenizer):
args.last_log = time.time()
metric = evaluate.load('rouge')
samples_seen = 0
def decode(preds):
preds[preds == -100] = tokenizer.pad_token_id
preds = tokenizer.batch_decode(
preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
preds = [pred.strip() for pred in preds]
return preds
for step, batch in enumerate(dataloader):
predictions = model.generate(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
max_length=args.data.max_target_len,
generation_config=model.generation_config,
)
predictions = decode(predictions)
references = decode(batch["labels"])
# If we are in a multiprocess environment, the last batch has duplicates
if step == len(dataloader) - 1:
predictions = predictions[: len(dataloader.dataset) - samples_seen]
references = references[: len(dataloader.dataset) - samples_seen]
else:
samples_seen += len(references)
metric.add_batch(
predictions=predictions,
references=references,
)
eval_metric = metric.compute(use_stemmer=True, use_aggregator=False)
rougeL = sum(eval_metric["rougeL"]) * 100 / len(eval_metric["rougeL"])
logger.log_stats(
stats={
"rougeL": rougeL,
"time": time.time() - args.last_log,
},
step=args.current_train_step,
args=args,
prefix="test/",
)
def train(model, train_dataloader, test_dataloader, accelerator, lr_scheduler,
optimizer, logger, args, tokenizer):
model.train()
train_averager = Averager()
while args.current_train_step <= args.optim.total_steps:
if isinstance(train_dataloader.dataset, IterableDataset):
train_dataloader.dataset.set_epoch(args.current_epoch)
for batch_id, batch in enumerate(train_dataloader, start=1):
if args.current_train_step > args.optim.total_steps:
break
loss, stats = forward(model, batch)
accelerator.backward(loss / args.optim.grad_acc)
train_averager.update(stats)
if batch_id % args.optim.grad_acc == 0:
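                # One optimizer update per grad_acc micro-batches: clip/step/zero_grad, followed by
                # periodic logging, evaluation, and checkpointing.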
stats = maybe_grad_clip_and_grad_calc(accelerator, model, args)
train_averager.update(stats)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
maybe_logging(train_averager, args, model, optimizer, logger)
maybe_eval_predict(model, test_dataloader, logger, args, tokenizer)
maybe_save_checkpoint(accelerator, args)
args.current_train_step += 1
args.current_epoch += 1
maybe_eval_predict(model, test_dataloader, logger, args, tokenizer)
maybe_save_checkpoint(accelerator, args)

15
requirements.txt Normal file
View File

@ -0,0 +1,15 @@
accelerate
datasets >= 1.8.0
sentencepiece != 0.1.92
transformers
neptune
pdbpp
notebook
protobuf==3.20.*
pyyaml
pynvml
hydra-core
evaluate
nltk
absl-py
rouge_score