```python
# test game
# for i in range(1, 16):
#     print(fizz_buzz_decode(i, fizz_buzz_encode(i)), end=' ')
```
First, we define the model's inputs and outputs (the training data):
```python
import torch
import torch.nn as nn
import numpy as np

def num_bin_encode(num):
    # encode an integer as a little-endian binary vector of NUM_DIGITES bits
    return np.array([num >> d & 1 for d in range(NUM_DIGITES)])

# train on the numbers 101 .. 2 ** NUM_DIGITES - 1; 1 .. 100 is held out for testing
trX = torch.Tensor([num_bin_encode(i) for i in range(101, 2 ** NUM_DIGITES)])
trY = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITES)])
```
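For intuition, the encoder returns the little-endian bit vector of a number. The exact value of `NUM_DIGITES` is not shown in this excerpt; assuming it is 10 (consistent with the training range above), a quick sanity check looks like this:

```python
# assumes NUM_DIGITES = 10; 11 = 0b1011, least significant bit first
print(num_bin_encode(11))   # [1 1 0 1 0 0 0 0 0 0]
```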
Then we define the model with PyTorch:
```python
NUM_HIDDEN = 100

model = nn.Sequential(
    nn.Linear(NUM_DIGITES, NUM_HIDDEN),
    nn.ReLU(),
    nn.Linear(NUM_HIDDEN, 4)
)
```
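The training loop below also needs a loss function and an optimizer, which are not shown in this excerpt. A minimal sketch, assuming cross-entropy loss and plain SGD (the learning rate is an assumed value):

```python
loss_fn = nn.CrossEntropyLoss()                            # 4-way classification loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)   # lr = 0.05 is an assumption
```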
```python
BATCH_SIZE = 128
for epoch in range(10000):
    for start in range(0, len(trX), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = trX[start:end]
        batchY = trY[start:end]

        y_pre = model(batchX)
        loss = loss_fn(y_pre, batchY)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # report the loss on the full training set once per epoch
    loss = loss_fn(model(trX), trY).item()
    print('epoch:', epoch, 'loss:', loss)
```
Finally, we let the trained model play the game on the numbers 1 to 100, which it never saw during training:

```python
testX = torch.Tensor([num_bin_encode(i) for i in range(1, 101)])
with torch.no_grad():
    testY = model(testX)
# take the highest-scoring class for each number as the prediction
predictions = zip(range(1, 101), testY.max(1)[1])

print([fizz_buzz_decode(num, index) for num, index in predictions])
```
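If you also want a single accuracy number on the held-out range, one way to compute it (assuming `fizz_buzz_encode` gives the ground-truth class index, as it does for the training labels above):

```python
with torch.no_grad():
    test_true = torch.LongTensor([fizz_buzz_encode(i) for i in range(1, 101)])
    test_pred = model(testX).max(1)[1]
    print('test accuracy:', (test_pred == test_true).float().mean().item())
```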