pytorch使用gpu

安装环境

一般nvidia与cuda是安装好的,没好在管网查。

再就是安装anaconda,里面添加一个py38的环境吧。

最后安装pytorch不能像其他组件一样在pycharm里安装,得在这个页面,选相应版本安装。https://pytorch.org/get-started/locally/

比如我的。先conda activate py38

再用页面的pip命令安装。

应该选conda也行。


运行结果:



一块gpu。

代码:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import torch.nn.functional as F
from torchmetrics import R2Score
from torchmetrics import MeanAbsoluteError
from torchmetrics import MeanSquaredError

import torch
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

print("cuda is available:"+str(torch.cuda.is_available()))
print("gpu:"+str(torch.cuda.device_count()))
# 检查是否有 GPU 支持
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(matplotlib.matplotlib_fname())
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题



#data_csv = pd.read_csv('data/rnn.spider.csv', usecols=[2])
data_csv = pd.read_csv('data/rnn_spider_9.csv', usecols=[2])
plt.plot(data_csv)
#plt.show()

# 数据预处理
data_csv = data_csv.dropna() #݄去掉  na
dataset = data_csv.values
olddataset = dataset.astype('float32')
max_value = np.max(dataset)
min_value = np.min(dataset)
scalar = max_value - min_value

dataset = list(map(lambda x: (x-min_value) / scalar, olddataset))
#dataset = list(map(lambda x: x / scalar, olddataset))
#X每两行形成一组,如[284,1425],[1425,2674],形成[0.1,0.5],
#Y从下标2开始每行形成一组,如[2674]形成[1.0]
def create_dataset(dataset, look_back=2):
    dataX, dataY = [], []
    for i in range(len(dataset) - look_back):
        a = dataset[i:(i + look_back)]
        dataX.append(a)
        dataY.append(dataset[i + look_back])
    return np.array(dataX), np.array(dataY)



# 创建好输入输出
data_X, data_Y = create_dataset(dataset)


# 划分训练集和测试集,   70% 作为训练集
train_size = int(len(data_X) * 0.7)
test_size = len(data_X) - train_size
train_X = data_X[:train_size]
train_Y = data_Y[:train_size]
test_X = data_X[train_size:]
test_Y = data_Y[train_size:]

train_max_value = np.max(train_X)
train_min_value = np.min(train_X)
# 三维数组reshape函数中 第一个参数2代表深度,第二个参数3代表行,第三个参数4代表列。

train_X = train_X.reshape(-1, 1, 2)
train_Y = train_Y.reshape(-1, 1, 1)
test_X = test_X.reshape(-1, 1, 2)
train_x = torch.from_numpy(train_X).to(device)
train_y = torch.from_numpy(train_Y).to(device)
test_x = torch.from_numpy(test_X).to(device)


from torch import nn
from torch.autograd import Variable


#torch.device('cpu')
# 定义模型
class lstm_reg(nn.Module):
    def __init__(self, input_size, hidden_size, output_size=1, num_layers=2):
        super(lstm_reg, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers) # rnn
        self.reg = nn.Linear(hidden_size, output_size) #回归
    def forward(self, x):
        x, _ = self.rnn(x) # (seq, batch, hidden)
        s, b, h = x.shape
        x = x.view(s*b, h) # 转换成线性层的输入格式
        x = self.reg(x)
        x = x.view(s, b, -1)
        return x

#2行,4个隐藏层
net = lstm_reg(2, 4).to(device)
criterion = nn.MSELoss()
#学习率真为0.01
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

print(net)
# 开始训练
for e in range(1000):
    var_x = Variable(train_x)
    var_y = Variable(train_y)
    # 前向传播
    out = net(var_x)
    loss = criterion(out, var_y)
    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (e + 1) % 100 == 0: #每  100 次输出结果
        print('Epoch: {}, Loss: {:.5f}'.format(e + 1, loss.data))

net = net.eval() # 转换成测试模式

data_X = data_X.reshape(-1, 1, 2)
data_X = torch.from_numpy(data_X).to(device)
var_data = Variable(data_X)
pred_test = net(var_data) #测试集的预测结果
if device!="cpu":
    pred_test=pred_test.cpu()
# 改变输出的格式
pred_test = pred_test.view(-1).data.numpy()



#print_pred_test=net(torch.tensor([[0.8,0.6]]).unsqueeze(0).float())
#print(print_pred_test)
#print(print_pred_test*(max_value-min_value))
#print(olddataset*(max_value-min_value))

# 画出实际结果和预测的结果
plt.plot(pred_test, 'r', label='预测值')
plt.plot(dataset, 'b', label='真实值')
plt.legend(loc='best')
plt.show()

# 计算R2指标
r2 = R2Score()

realData=[]
for sublist in dataset:
    # 再次使用for循环遍历子列表中的元素
    for element in sublist:
        realData.append(element)
r2(torch.as_tensor(np.array(pred_test)), torch.as_tensor(np.array(realData[:-2])))
# 输出R2指标值
print("R2 score:", r2.compute())


mae=MeanAbsoluteError()
mae(torch.as_tensor(np.array(pred_test)), torch.as_tensor(np.array(realData[:-2])))
print("mae score:", mae.compute())

mse=MeanSquaredError()
mse(torch.as_tensor(np.array(pred_test)), torch.as_tensor(np.array(realData[:-2])))
print("mse score:", mse.compute())

print("rmse score:", torch.sqrt(mse.compute()))
'''
mae = F.l1_loss( torch.tensor(pred_test),  torch.tensor(dataset))
print(mae)

mae = F.Rl1_loss( torch.tensor(pred_test),  torch.tensor(dataset))
'''



数据

rnn_spider_9.csv

id,fileNum,totalTime/fileNum
683,1409,0.5862
682,865,0.6069
681,474,0.6371
678,482,0.5830
677,1385,0.7487
676,142,0.6127
675,795,0.6226
674,156,0.5577
669,61,0.5738
668,1139,0.6084
667,1346,0.6293
666,365,0.5890
653,56,0.6964
652,567,0.6155
651,624,0.6122
650,468,0.5876
649,670,0.6463
648,538,0.6041
645,257,0.5642
644,928,0.5636
643,169,0.6509
642,1078,0.5946
641,289,0.5952
640,167,0.5988
638,427,0.5738
637,669,0.5800
636,488,0.6434
635,1099,0.6497
634,930,0.5634
633,1401,0.6089
632,319,0.6395
631,406,0.7266
630,91,0.6264
629,23,0.7391
628,86,0.7209
625,23,1.4348
624,934,0.9550
623,180,0.6333
622,642,0.6495
621,717,0.6388
620,246,0.5854
619,182,0.6538
618,66,0.6212
613,149,0.5973
612,417,0.5779
611,1228,0.6466
610,499,0.6453
609,712,0.6236
598,881,1.1975
597,1123,0.6438
596,932,0.6148
591,600,0.6233
590,1062,0.6507
589,714,0.6148
588,667,1.2309
583,182,1.4505
582,799,0.6671
581,499,0.6814
580,530,0.6377
579,848,0.7028
572,480,0.6667
571,1288,0.6382
570,808,0.6238
569,183,0.6393
568,838,0.6337
564,66,0.6061
563,669,0.6368
562,371,0.6792
561,68,0.5735
560,363,0.6556
559,968,0.6312
558,470,0.6447
557,672,0.6429
554,326,0.6472
553,178,0.6404
552,916,0.6354
551,155,0.7548
550,216,0.6250
549,632,0.6392
548,518,0.6622
541,33,0.6364
540,766,0.6397
539,99,0.6162
538,1413,0.6369
537,665,0.6436
527,661,0.6399
526,1496,0.6310
525,428,0.6285
524,403,0.6551
519,57,0.6667
518,1397,0.6464
517,259,0.6873
516,136,0.6838
515,35,0.6286
514,597,0.6683
513,27,0.7037
509,703,0.6358
508,112,0.6071
507,812,0.5924
506,1362,0.5675
505,544,0.5993
504,1347,0.6288
503,989,0.6643
499,886,0.6716
498,162,0.7716
497,1120,0.6625
493,613,0.7015
487,383,0.6240
486,1390,0.6417
485,119,0.6303
484,879,0.6667
483,298,0.6477
482,385,0.6494
480,92,0.6522
479,1193,0.6438
478,1075,0.6335
477,775,0.6284
476,489,0.6380
463,206,0.6359
462,1316,0.6071
461,247,0.5870
460,36,0.5556
459,342,0.6374
458,56,0.6786
457,497,0.6318
456,329,0.6383
446,847,0.6198
445,1184,0.6453
444,304,0.6875
443,327,0.5688
442,311,0.5145
436,158,0.5443
435,493,0.5659
434,1462,0.6005
433,944,0.6356
432,687,0.6317
431,427,0.6745
430,579,0.6269
429,32,0.6250
428,479,0.7035
422,632,0.6598
421,1384,0.6575
420,94,0.6489
419,777,0.6499
414,605,0.7025
413,1148,0.6516
412,435,0.6437
411,674,0.6899
407,732,0.6325
406,1456,0.6518
405,221,0.6968
404,545,0.6312
402,381,0.6220
401,1430,0.6245
400,338,0.6243
399,455,0.6813
398,355,0.6620
393,978,0.6196
392,561,0.6132
391,163,0.5951
390,893,0.5980
389,375,0.6453
388,90,0.5889
383,220,0.6682
382,689,0.6488
381,970,0.6227
380,1058,0.6172
376,788,0.6345
375,333,0.5796
374,52,0.6731
373,741,0.6478
372,357,0.6022
371,670,0.6373
365,851,0.6428
364,976,0.8145
363,852,0.6631
361,387,0.6098
360,274,0.6460
355,878,0.6378
354,825,0.6255
353,897,0.6410
352,407,0.6536
351,1193,0.6203
350,581,0.6368
349,433,0.6005
348,229,0.5764
344,261,0.6322
343,1076,0.6747
342,11,8.8182
341,1262,0.6458
340,892,0.6670
339,444,0.6486
338,810,0.6802
337,544,0.6507
336,393,0.6616
335,98,0.7347
330,1140,0.6623
329,1089,0.6804
328,266,0.6579
327,720,0.6306
326,320,0.6688
325,1291,0.6414
324,268,0.7127
323,485,0.6165
322,273,0.6374
316,193,0.5389
315,493,0.6085
314,918,0.6133
313,1224,0.6242
312,116,0.6293
310,823,0.7400
309,60,0.6500
308,167,0.7425
307,190,3.7895
306,141,0.7021
305,362,0.7182
304,93,0.6774
302,98,0.7755
301,333,0.7327
300,412,0.7112
299,382,0.7487
298,636,0.7327
297,933,0.7256
296,149,0.7517
292,1118,0.7504
291,1183,1.5097
290,344,0.7587
289,217,0.8203
279,620,0.7613
278,450,0.7156
277,1253,0.7007
276,722,0.6856
271,820,0.7085
270,967,0.7787
269,1120,0.7429
268,127,0.7480
261,388,0.7964
260,997,0.7593
259,960,0.8063
258,573,0.7243
256,179,0.6927
255,1472,0.7310
254,757,0.7596
253,503,0.7614
249,805,0.7565
248,739,0.7442
247,1323,0.7385
238,1121,0.7101
237,1066,0.7242
236,137,0.6788
235,625,0.7728
229,291,0.7320
228,109,0.7615
227,824,0.6917
226,1115,0.7031
225,650,0.7215
224,521,0.7179
223,1131,0.6932
222,562,0.6726
217,860,0.6756
216,1386,0.6710
215,800,0.6738
208,451,0.7095
207,651,0.6390
206,330,0.6394
205,391,0.6726
204,320,1.5750
203,800,0.6475
197,277,0.6354
196,305,0.6230
195,337,0.6202
194,558,1.1219
193,1008,0.6250
192,125,0.7200
188,566,0.6484
187,481,0.6258
186,335,0.6388
185,1231,0.6255
184,1027,0.6475
183,505,0.6277
182,754,0.6737
181,820,0.6134
180,325,0.6431
179,643,0.6174
173,293,0.6348
162,612,1.1225
161,504,1.1905
160,1554,1.1782
159,1661,1.0108
158,266,0.8609
157,1155,0.8970
156,337,0.9644
155,1681,0.9590
154,1827,0.8752
153,1257,0.8815
152,1452,0.8333
151,936,0.8665
150,1159,0.9094
149,1616,0.9462
148,581,0.8090
147,965,0.8435
146,671,1.0313
145,1079,1.0074
144,67,1.0597
142,116,1.0862
141,1084,1.3782
140,332,1.0693
139,1406,1.0462
138,1866,1.0263
137,555,1.0198
136,1551,1.0393
135,1395,1.0975
134,500,0.9920
133,1143,1.0647
132,232,1.1422
131,764,1.0681
130,618,1.1165
129,755,1.0477
128,258,1.1395
127,1319,1.0447
126,464,1.2823
124,254,1.2362
123,1296,1.1566
122,1077,1.2228
121,291,1.3299
119,1031,1.1736
117,1849,1.4462
116,1690,1.1438
115,1254,1.0654
114,1418,1.0755
113,1052,1.2376
112,702,1.1339

文/程忠 浏览次数:0次   2024-02-29 10:02:53

相关阅读


评论:
点击刷新

↓ 广告开始-头部带绿为生活 ↓
↑ 广告结束-尾部支持多点击 ↑