solution v2
This commit is contained in:
parent
3d5789c3bd
commit
a138439024
151
.gitignore
vendored
151
.gitignore
vendored
@ -1,149 +1,4 @@
|
||||
# ---> JupyterNotebooks
|
||||
# gitignore template for Jupyter Notebooks
|
||||
# website: http://jupyter.org/
|
||||
|
||||
.ipynb_checkpoints
|
||||
*/.ipynb_checkpoints/*
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# Remove previous ipynb_checkpoints
|
||||
# git rm -r .ipynb_checkpoints/
|
||||
|
||||
# ---> Python
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
.ipynb_checkpoints/
|
||||
model/
|
||||
geval
|
||||
|
||||
processed_train.txt
|
||||
|
63
run.ipynb
63
run.ipynb
@ -93,24 +93,11 @@
|
||||
"execution_count": 6,
|
||||
"id": "7a51f2b1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'SimpleBigramNeuralLanguageModel' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m/tmp/ipykernel_16179/3272155308.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'cuda'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mtrain_dataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBigrams\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'europarl.txt'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocab_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSimpleTrigramNeuralLanguageModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0membed_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_dataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAdam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/tmp/ipykernel_16179/1892442743.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, vocabulary_size, embedding_size)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mSimpleTrigramNeuralLanguageModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocabulary_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0membedding_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mSimpleBigramNeuralLanguageModel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEmbedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocabulary_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0membedding_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0membedding_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocabulary_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'SimpleBigramNeuralLanguageModel' is not defined"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"\n",
|
||||
"device = 'cuda'\n",
|
||||
"device = 'cpu'\n",
|
||||
"train_dataset = Bigrams('europarl.txt', vocab_size)\n",
|
||||
"model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
|
||||
"data = DataLoader(train_dataset, batch_size=2000)\n",
|
||||
@ -120,11 +107,41 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"id": "474194ae",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0 tensor(10.0424, grad_fn=<NllLossBackward0>)\n",
|
||||
"100 tensor(7.9016, grad_fn=<NllLossBackward0>)\n",
|
||||
"200 tensor(7.1964, grad_fn=<NllLossBackward0>)\n",
|
||||
"300 tensor(6.5661, grad_fn=<NllLossBackward0>)\n",
|
||||
"400 tensor(6.4146, grad_fn=<NllLossBackward0>)\n",
|
||||
"500 tensor(5.8718, grad_fn=<NllLossBackward0>)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[1;32m<ipython-input-8-06724a2e87a4>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 7\u001b[0m \u001b[0my\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 9\u001b[1;33m \u001b[0moutputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 10\u001b[0m \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mstep\u001b[0m \u001b[1;33m%\u001b[0m \u001b[1;36m100\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1108\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1109\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1110\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1111\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1112\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32m<ipython-input-2-4f6f391f0eb8>\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 10\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 12\u001b[1;33m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 13\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 14\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1108\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1109\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1110\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1111\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1112\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\torch\\nn\\modules\\linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 101\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 103\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 104\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 105\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"step = 0\n",
|
||||
"\n",
|
||||
"for epoch in range(1):\n",
|
||||
" model.train()\n",
|
||||
" for x, y in data:\n",
|
||||
@ -141,11 +158,19 @@
|
||||
" \n",
|
||||
"torch.save(model.state_dict(), 'model/model1.bin')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "alpha-leonard",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -159,7 +184,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.9.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
Loading…
Reference in New Issue
Block a user