Machine Learning - 03 FizzBuzz Game

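FizzBuzz is a simple game. The rules are: count upward starting from 1. For a multiple of 3 you say "fizz", for a multiple of 5 you say "buzz", and for a multiple of 15 you say "fizzbuzz". Multiples of 15 are also multiples of 3 and 5, so the fizzbuzz rule takes precedence. For every other number you simply say the number.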

We can write a small program that decides whether to return the number itself or fizz, buzz, or fizzbuzz.

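```python
NUM_DIGITES = 10  # number of binary digits used to encode each input number

def fizz_buzz_encode(num):
    # Map a number to its class index: 3 = fizzbuzz, 2 = buzz, 1 = fizz, 0 = the number itself
    if num % 15 == 0:
        return 3
    elif num % 5 == 0:
        return 2
    elif num % 3 == 0:
        return 1
    else:
        return 0


def fizz_buzz_decode(num, index):
    # Map a class index back to what the player should say
    return [str(num), 'fizz', 'buzz', 'fizzbuzz'][index]

# test game
# for i in range(1, 16):
#     print(fizz_buzz_decode(i, fizz_buzz_encode(i)), end=' ')
```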
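Encoding each answer as one of four class indices is what turns the game into a 4-class classification problem. As a quick sanity check, here is a minimal sketch that compares `decode(encode(n))` against a direct string-based implementation of the rules. `fizz_buzz_direct` is a hypothetical helper introduced only for this comparison; the two original helpers are copied so the block runs on its own.

```python
def fizz_buzz_encode(num):
    # Class index: 3 = fizzbuzz, 2 = buzz, 1 = fizz, 0 = the number itself
    if num % 15 == 0:
        return 3
    elif num % 5 == 0:
        return 2
    elif num % 3 == 0:
        return 1
    else:
        return 0


def fizz_buzz_decode(num, index):
    return [str(num), 'fizz', 'buzz', 'fizzbuzz'][index]


def fizz_buzz_direct(num):
    # Hypothetical reference: build the answer from the rules directly as strings
    out = ('fizz' if num % 3 == 0 else '') + ('buzz' if num % 5 == 0 else '')
    return out or str(num)


# decode(encode(n)) should agree with the direct rule for every number we will use
assert all(fizz_buzz_decode(n, fizz_buzz_encode(n)) == fizz_buzz_direct(n)
           for n in range(1, 1024))
print(' '.join(fizz_buzz_decode(n, fizz_buzz_encode(n)) for n in range(1, 16)))
# prints: 1 2 fizz 4 buzz fizz 7 8 fizz buzz 11 fizz 13 14 fizzbuzz
```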

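First we define the model's inputs and outputs (the training data). Each number is encoded as a vector of NUM_DIGITES = 10 binary digits, least significant bit first, and its label is the class index returned by fizz_buzz_encode. We train on the numbers 101 to 1023 (2^10 - 1) and hold out 1 to 100 for testing, so the model never sees the test numbers during training.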

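```python
import torch
import torch.nn as nn
import numpy as np

def num_bin_encode(num):
    # Binary digits of num, least significant bit first: bit d is (num >> d) & 1
    return np.array([num >> d & 1 for d in range(NUM_DIGITES)])

# Train on 101 .. 2**NUM_DIGITES - 1 (= 1023); 1 .. 100 is held out for testing.
# Stacking into one numpy array first avoids building a tensor from a list of arrays.
trX = torch.Tensor(np.array([num_bin_encode(i) for i in range(101, 2 ** NUM_DIGITES)]))
trY = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITES)])
```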
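To make the input encoding concrete, here is a small numpy-only sketch of what num_bin_encode produces and what shape the training matrix ends up with. The names `train_nums`, `test_nums` and `X_train` are illustrative only and do not appear in the original code.

```python
import numpy as np

NUM_DIGITES = 10  # same constant as above: 10 bits cover 0 .. 1023


def num_bin_encode(num):
    # Bit d of num, least significant bit first
    return np.array([num >> d & 1 for d in range(NUM_DIGITES)])


# 6 = 0b110, so bits 1 and 2 are set (LSB first)
print(num_bin_encode(6))  # [0 1 1 0 0 0 0 0 0 0]

# The training range 101 .. 1023 and the test range 1 .. 100 do not overlap
train_nums = range(101, 2 ** NUM_DIGITES)
test_nums = range(1, 101)
X_train = np.stack([num_bin_encode(i) for i in train_nums])
print(X_train.shape)  # (923, 10)
```

Since 2^10 = 1024, ten bits are enough to represent every number in both ranges without collisions.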

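Then we define the model in PyTorch: a network with one hidden layer of 100 units that outputs one score (logit) for each of the 4 classes.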

```python
NUM_HIDDEN = 100

model = nn.Sequential(
    nn.Linear(NUM_DIGITES, NUM_HIDDEN),
    nn.ReLU(),
    nn.Linear(NUM_HIDDEN, 4)
)
```
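To see what this Sequential model computes, here is a numpy sketch of the same Linear -> ReLU -> Linear forward pass. It assumes the nn.Linear convention that the weight has shape (out_features, in_features) and the layer computes x @ W.T + b. The weights are random and serve only to illustrate shapes; the last layer has no softmax because nn.CrossEntropyLoss expects raw logits.

```python
import numpy as np

NUM_DIGITES, NUM_HIDDEN, NUM_CLASSES = 10, 100, 4
rng = np.random.default_rng(0)

# Random weights for illustration only, stored as (out_features, in_features)
W1 = rng.normal(size=(NUM_HIDDEN, NUM_DIGITES))
b1 = np.zeros(NUM_HIDDEN)
W2 = rng.normal(size=(NUM_CLASSES, NUM_HIDDEN))
b2 = np.zeros(NUM_CLASSES)


def forward(x):
    # x: (batch, 10) -> hidden: (batch, 100) -> logits: (batch, 4)
    h = np.maximum(0.0, x @ W1.T + b1)  # Linear + ReLU
    return h @ W2.T + b2                # Linear, no softmax


x = rng.integers(0, 2, size=(5, NUM_DIGITES)).astype(float)  # 5 fake bit vectors
print(forward(x).shape)  # (5, 4)

# Trainable parameters: 10*100 + 100 + 100*4 + 4
n_params = W1.size + b1.size + W2.size + b2.size
print(n_params)  # 1504
```

The same count should come out of `sum(p.numel() for p in model.parameters())` for the PyTorch model above.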
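  • To teach the model the FizzBuzz game, we need to define a loss function and an optimization algorithm.
  • The optimizer repeatedly updates the model's parameters to lower the loss function, so that the model reaches as low a loss as possible on this task.
  • A low loss usually means the model performs well; a high loss means it performs poorly.
  • Since FizzBuzz is essentially a 4-class classification problem, we use the cross-entropy loss (nn.CrossEntropyLoss).
  • For the optimizer we use stochastic gradient descent (SGD) with a learning rate of 0.05.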
```python
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
```
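For intuition about what these two objects do, here is a numpy sketch of the mean cross-entropy over a batch, its gradient with respect to the logits, and one plain SGD update with the same learning rate. In the real model, SGD updates the weights (parameter minus lr times gradient) rather than the logits; applying the step to the logits directly is a simplification for illustration.

```python
import numpy as np


def cross_entropy(logits, targets):
    # Mean over the batch of -log softmax(logits)[target], computed stably
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


def cross_entropy_grad(logits, targets):
    # d(loss)/d(logits) = (softmax(logits) - one_hot(targets)) / batch_size
    z = logits - logits.max(axis=1, keepdims=True)
    p = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    p[np.arange(len(targets)), targets] -= 1.0
    return p / len(targets)


logits = np.zeros((2, 4))      # a uniform prediction over the 4 classes
targets = np.array([3, 0])     # "fizzbuzz" and "the number itself"
print(cross_entropy(logits, targets))  # log(4), about 1.386

# One SGD step with lr = 0.05, applied to the logits for illustration
lr = 0.05
new_logits = logits - lr * cross_entropy_grad(logits, targets)
print(cross_entropy(new_logits, targets) < cross_entropy(logits, targets))  # True
```

A model that knows nothing and predicts all four classes equally gets a loss of log(4), which is a useful reference point when reading the training log.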

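Below is the training code. Each epoch walks through the training data in mini-batches of 128 samples, and after each epoch we print the loss on the whole training set.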

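```python
BATCH_SIZE = 128
for epoch in range(10000):
    # Mini-batch SGD over the training set, in a fixed order
    for start in range(0, len(trX), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = trX[start:end]
        batchY = trY[start:end]

        y_pre = model(batchX)          # forward pass: logits of shape (batch, 4)
        loss = loss_fn(y_pre, batchY)
        optimizer.zero_grad()          # clear gradients from the previous step
        loss.backward()                # backpropagate
        optimizer.step()               # update the parameters

    # Loss on the whole training set after this epoch (no gradients needed)
    with torch.no_grad():
        loss = loss_fn(model(trX), trY).item()
    print('epoch:', epoch, 'loss:', loss)
```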

```
epoch: 0 loss: 1.1909565925598145
…………
epoch: 9999 loss: 0.0071606868878006935
```
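The inner loop slices the 923 training samples into consecutive batches; slicing past the end simply returns the remaining rows, so the last batch is smaller. The sketch below reproduces that index arithmetic in plain Python. It also shows a variant that the original code does not use: reshuffling the sample order every epoch (in PyTorch this would typically be done with `torch.randperm`); the names `batches`, `perm` and `shuffled_batches` are illustrative.

```python
import numpy as np

BATCH_SIZE = 128
n_samples = len(range(101, 2 ** 10))  # 923 training numbers, as in trX

# Same slicing as the training loop: start = 0, 128, 256, ...
batches = [(start, min(start + BATCH_SIZE, n_samples))
           for start in range(0, n_samples, BATCH_SIZE)]
print(len(batches))   # 8
print(batches[-1])    # (896, 923): the last batch holds only 27 samples

# Hypothetical variant: visit the samples in a new random order each epoch
rng = np.random.default_rng(0)
perm = rng.permutation(n_samples)
shuffled_batches = [perm[start:start + BATCH_SIZE]
                    for start in range(0, n_samples, BATCH_SIZE)]
# Every sample still appears exactly once per epoch
assert sorted(np.concatenate(shuffled_batches).tolist()) == list(range(n_samples))
```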

Finally, we use the trained model to play FizzBuzz on the numbers 1 to 100, none of which it saw during training.

```python
testX = torch.Tensor(np.array([num_bin_encode(i) for i in range(1, 101)]))
with torch.no_grad():
    testY = model(testX)  # logits of shape (100, 4)
# max(1) returns (values, indices) along dim 1; [1] keeps the predicted class indices
predictions = zip(range(1, 101), testY.max(1)[1])

print([fizz_buzz_decode(num, index) for num, index in predictions])
```
```
['1', '2', 'fizz', 'buzz', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', '19', 'buzz', 'fizz', '22', '23', 'fizz', '25', 'buzz', 'fizz', '28', '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', '38', 'fizz', 'buzz', '41', '42', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', 'buzz', 'buzz', 'fizz', '67', 'buzz', 'fizz', 'buzz', '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', 'fizz', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'buzz', '86', '87', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz']
```
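`testY.max(1)` returns the maximum values and their indices along dimension 1, so `testY.max(1)[1]` is the predicted class for each row, the same thing `testY.argmax(1)` gives. The numpy sketch below shows this decoding step on three rows of made-up logits; the numbers are hypothetical and not taken from the trained model.

```python
import numpy as np

# Hypothetical logits for 3 numbers (rows) over the 4 classes (columns)
logits = np.array([[2.0, 0.1, -1.0, 0.3],   # for 1
                   [0.2, 1.5,  0.0, 0.1],   # for 3
                   [0.0, 0.3,  0.4, 2.2]])  # for 15
nums = [1, 3, 15]

indices = logits.argmax(axis=1)  # plays the role of testY.max(1)[1]
labels = [[str(n), 'fizz', 'buzz', 'fizzbuzz'][i] for n, i in zip(nums, indices)]
print(labels)  # ['1', 'fizz', 'fizzbuzz']
```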
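```python
# Count the correct predictions: the comparison yields a boolean array,
# and np.sum counts True as 1 and False as 0
print(np.sum(testY.max(1)[1].numpy() == np.array([fizz_buzz_encode(i) for i in range(1, 101)])))
```
```
93
```
So the model answers 93 of the 100 test numbers correctly.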
Comparing the predictions element by element shows which positions are wrong:
```python
testY.max(1)[1].numpy() == np.array([fizz_buzz_encode(i) for i in range(1, 101)])
```
```
array([ True,  True,  True, False,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True, False,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
       False,  True,  True,  True, False,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True, False,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True])
```
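The seven False entries can be mapped back to concrete numbers. This sketch takes the printed predictions for 1 to 100 (copied verbatim from the output above), recomputes the correct answers from the game's rules with a hypothetical helper `expected`, and lists the misses. Comparing strings is equivalent to comparing class indices here, because each number decodes to a distinct string for each class.

```python
# The model's answers for 1..100, copied from the printed output above
predicted = [
    '1', '2', 'fizz', 'buzz', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz',
    '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', '19', 'buzz',
    'fizz', '22', '23', 'fizz', '25', 'buzz', 'fizz', '28', '29', 'fizzbuzz',
    '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', '38', 'fizz', 'buzz',
    '41', '42', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz',
    'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz',
    '61', '62', 'fizz', 'buzz', 'buzz', 'fizz', '67', 'buzz', 'fizz', 'buzz',
    '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', 'fizz', '79', 'buzz',
    'fizz', '82', '83', 'fizz', 'buzz', '86', '87', '88', '89', 'fizzbuzz',
    '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz',
]


def expected(n):
    # The correct answer according to the game's rules
    if n % 15 == 0:
        return 'fizzbuzz'
    if n % 5 == 0:
        return 'buzz'
    if n % 3 == 0:
        return 'fizz'
    return str(n)


wrong = [n for n, p in zip(range(1, 101), predicted) if p != expected(n)]
print(len(predicted) - len(wrong))  # 93
print(wrong)                        # [4, 25, 26, 42, 64, 68, 87]
```

For example, 4 was answered "buzz" and 25 was answered "25", which matches the False positions in the array above.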