160 lines
3.9 KiB
Plaintext
160 lines
3.9 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch \n",
|
|
"import pandas\n",
|
|
"import numpy\n",
|
|
"from torch.autograd import Variable"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LinearRegressionModel(torch.nn.Module):\n",
"    \"\"\"Single-feature linear regression model: y_pred = w * x + b.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super(LinearRegressionModel, self).__init__()\n",
"        # One input feature -> one output value.\n",
"        self.linear = torch.nn.Linear(1, 1)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply the affine layer to x; the last dimension of x must be 1.\"\"\"\n",
"        y_pred = self.linear(x)\n",
"        return y_pred"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Tab-separated file with no header row: columns are addressed by integer position.\n",
"# NOTE(review): absolute local path; consider a configurable DATA_DIR.\n",
"TRAIN_PATH = 'C:/Users/eryk6/PycharmProjects/mieszkania5/train/train.tsv'\n",
"train_data = pandas.read_csv(TRAIN_PATH, header=None, sep='\\t')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def clean_column(frame, col):\n",
"    \"\"\"Return frame[col] as an (n, 1) float32 array.\n",
"\n",
"    Each value is stringified and all spaces are removed before parsing, so\n",
"    numbers written with space separators (e.g. '350 000') still convert.\n",
"    \"\"\"\n",
"    values = [str(v).replace(' ', '') for v in frame[col].tolist()]\n",
"    return numpy.array(values, dtype=numpy.float32).reshape(-1, 1)\n",
"\n",
"\n",
"# Fix: the original used column 0 as the input and column 8 as the target, but the\n",
"# prediction cells feed column 7 of in.tsv to the model. Assuming the usual Gonito\n",
"# layout (train.tsv = expected value in column 0 followed by the in.tsv columns),\n",
"# in.tsv column 7 is train.tsv column 8, so column 8 is the feature and column 0\n",
"# is the target. NOTE(review): confirm against the challenge README.\n",
"x = clean_column(train_data, 8)\n",
"y = clean_column(train_data, 0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model = LinearRegressionModel()\n",
"criterion = torch.nn.MSELoss()\n",
"\n",
"# Plain SGD on an unscaled feature is stable only for roughly lr < 1 / mean(x**2).\n",
"# The original hand-picked lr=1e-12 only suits a feature whose mean square is about\n",
"# 1e11; deriving the rate from the training data keeps it stable for any input column.\n",
"LR_SCALE = 0.1\n",
"feature_power = float(numpy.mean(numpy.square(x)))\n",
"learning_rate = LR_SCALE / feature_power if feature_power > 0 else LR_SCALE\n",
"\n",
"# Fix: `our_model` was never defined (NameError); optimize the parameters of `model`.\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"N_EPOCHS = 500\n",
"\n",
"# The training arrays never change, so convert them to tensors once, outside the loop.\n",
"inputs = torch.from_numpy(x)\n",
"targets = torch.from_numpy(y)\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"    optimizer.zero_grad()  # gradients accumulate across backward() calls\n",
"    loss = criterion(model(inputs), targets)\n",
"    loss.backward()\n",
"    optimizer.step()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"TEST_A_IN_PATH = 'C:/Users/eryk6/PycharmProjects/mieszkania5/test-A/in.tsv'\n",
"test_A_in = pandas.read_csv(TEST_A_IN_PATH, header=None, sep='\\t')\n",
"# Column 7 is the model input; shape it as an (n, 1) float32 column vector.\n",
"x = numpy.asarray(test_A_in[7].tolist(), dtype=numpy.float32).reshape(-1, 1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Inference only: no_grad() skips building the autograd graph, so .numpy() works directly.\n",
"with torch.no_grad():\n",
"    y = model(torch.from_numpy(x)).numpy()\n",
"\n",
"# The with-block closes the file even if a write fails.\n",
"with open('C:/Users/eryk6/PycharmProjects/mieszkania5/test-A/out.tsv', 'w') as out:\n",
"    for row in y:\n",
"        out.write(str(row[0]) + '\\n')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"DEV_IN_PATH = 'C:/Users/eryk6/PycharmProjects/mieszkania5/dev-0/in.tsv'\n",
"dv_in = pandas.read_csv(DEV_IN_PATH, header=None, sep='\\t')\n",
"# Column 7 is the model input; shape it as an (n, 1) float32 column vector.\n",
"x = numpy.asarray(dv_in[7].tolist(), dtype=numpy.float32).reshape(-1, 1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Inference only: no_grad() skips building the autograd graph, so .numpy() works directly.\n",
"with torch.no_grad():\n",
"    y = model(torch.from_numpy(x)).numpy()\n",
"\n",
"# The with-block closes the file even if a write fails.\n",
"with open('C:/Users/eryk6/PycharmProjects/mieszkania5/dev-0/out.tsv', 'w') as output:\n",
"    for row in y:\n",
"        output.write(str(row[0]) + '\\n')"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3.6 (tensorflow)",
|
|
"language": "python",
|
|
"name": "tensorflow"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.8"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|