Training a Restricted Boltzmann Machine using PyTorch
We start by importing the needed libraries:
import torch
import numpy as np
We then set the sizes of the visible and hidden layers:
nv = 3
nh = 3
Next we define the hyperparameters of our training method. Here k is the number of Gibbs sampling steps used in contrastive divergence (CD-k):
learning_rate = 0.01
epochs = 10
batch_size = 100
k = 3
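Optionally, we can fix the random seeds so that repeated runs produce the same data and the same initial weights. This step is not required for training:
torch.manual_seed(0)  # optional: makes weight initialization and sampling reproducible
np.random.seed(0)     # optional: makes the generated dataset reproducible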
First we need some training data. For this simple example we will train the RBM on a dataset sampled uniformly from four distinct 3-bit states:
N = 10000
states = np.array([[1,0,0],[1,1,0],[1,0,1],[1,1,1]])
idx = np.random.choice(np.arange(len(states)), N)
data = states[idx]
data = torch.from_numpy(data)
train_loader = torch.utils.data.DataLoader(dataset=data.to(torch.float), batch_size=batch_size, shuffle=True)
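As an optional sanity check, we can confirm that the four states appear with roughly equal frequency in the generated data (a quick sketch, not part of the training pipeline):
uniques, counts = np.unique(data.numpy(), axis=0, return_counts=True)
print(uniques)
print(counts / N)  # each of the four states should appear with probability ~0.25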
Now we implement the RBM itself as a torch.nn.Module:
class RBM(torch.nn.Module):
    def __init__(self, nv, nh):
        super(RBM, self).__init__()
        # w couples visible and hidden units; a and b are the
        # visible and hidden biases respectively
        self.w = torch.nn.Parameter(torch.randn(nv, nh) * 0.01)
        self.a = torch.nn.Parameter(torch.zeros(nv))
        self.b = torch.nn.Parameter(torch.zeros(nh))

    def sample_h(self, v):
        # sample the hidden units given the visible units
        phv = torch.sigmoid(torch.matmul(v, self.w) + self.b)
        h = torch.bernoulli(phv)
        return h, phv

    def sample_v(self, h):
        # sample the visible units given the hidden units
        pvh = torch.sigmoid(torch.matmul(h, self.w.t()) + self.a)
        v = torch.bernoulli(pvh)
        return v, pvh

    def forward(self, v, k):
        # k steps of Gibbs sampling, alternating between the hidden
        # and visible layers (the negative phase of CD-k)
        h, phv = self.sample_h(v)
        for _ in range(k):
            v, pvh = self.sample_v(h)
            h, phv = self.sample_h(v)
        # detach so no gradients flow through the sampled chain
        return v.detach()

    def free_energy(self, v):
        # F(v) = -a.v - sum_j log(1 + exp(v.w_j + b_j));
        # softplus(x) = log(1 + exp(x)) is the numerically stable form
        vt = torch.matmul(v, self.a)
        ht = torch.sum(torch.nn.functional.softplus(torch.matmul(v, self.w) + self.b), dim=1)
        return -(vt + ht)
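To see how the pieces fit together, here is an illustrative round trip through an untrained model: sample_h maps a batch of visible states to hidden activations, and sample_v maps back. The names rbm_check, v0, h0, and v1 are throwaways used only for this check:
rbm_check = RBM(nv, nh)
v0 = torch.tensor([[1., 0., 1.]])      # one visible configuration
h0, ph0 = rbm_check.sample_h(v0)       # hidden sample and its probabilities
v1, pv1 = rbm_check.sample_v(h0)       # reconstructed visible sample
print(h0.shape, v1.shape)              # both torch.Size([1, 3]) since nv = nh = 3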
Next we define the training loop. Note that we move each batch to the same device as the model.
def train(rbm, train_loader, learning_rate, k, training_epochs):
    optimizer = torch.optim.Adam(rbm.parameters(), lr=learning_rate)
    for epoch in range(training_epochs):
        epoch_cost = 0.
        for batch in train_loader:
            batch = batch.view(-1, nv).to(device)
            # negative sample from k steps of Gibbs sampling
            v = rbm(batch, k)
            # contrastive divergence surrogate loss: free energy of
            # the data minus free energy of the model samples
            cost = torch.mean(rbm.free_energy(batch)) - torch.mean(rbm.free_energy(v))
            epoch_cost += cost.item()
            optimizer.zero_grad()
            cost.backward()
            optimizer.step()
        print('Epoch [{}/{}], cost: {:.4f}'.format(epoch + 1, training_epochs, epoch_cost))
And finally we instantiate the model and train it:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = RBM(nv, nh).to(device)
train(model, train_loader, learning_rate, k, epochs)
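Once training finishes, a natural check (a minimal sketch, not part of the training loop above) is to sample from the model by running the Gibbs chain from random starting states and comparing the resulting distribution with the training data. The chain length of 50 here is an arbitrary choice, longer than the k used during training:
with torch.no_grad():
    v_init = torch.bernoulli(torch.full((1000, nv), 0.5)).to(device)  # random starting states
    samples = model(v_init, 50)  # run a longer Gibbs chain at sampling time
uniques, counts = np.unique(samples.cpu().numpy(), axis=0, return_counts=True)
print(uniques)
print(counts / counts.sum())  # should concentrate on the four training states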