### Training a Restricted Boltzmann Machine using PyTorch

We start by importing the required libraries:

```python
import torch
import numpy as np
```

We then set up the dimensions of the visible and hidden layers:

```python
# Layer sizes: visible units, hidden units, and the total number of units.
nv, nh = 3, 3
n = nh + nv
```

Next we define the hyperparameters of our training method:

```python
# Training hyperparameters.
epochs, batch_size = 10, 100
learning_rate = 1e-2
# Number of Gibbs-sampling iterations; read as a global by RBM.forward.
k = 3
```

We first need to generate some training data for our model.
For this simple example we will train the RBM on a small training set consisting of four different states.

```python
# Build a toy training set: N rows drawn uniformly at random (with
# replacement) from four 3-bit states whose first bit is always 1.
N = 10000
states = np.array([[1, 0, 0], [1, 1, 0], [1, 0, 1], [1, 1, 1]])
idx = np.random.choice(len(states), N)
data = states[idx]
# torch.from_numpy keeps NumPy's integer dtype; cast to float so the samples
# can be matrix-multiplied with the floating-point RBM parameters.
data = torch.from_numpy(data).float()
```

Now we code the main RBM class that will define the RBM architecture:

```python
class RBM(torch.nn.Module):
    """Restricted Boltzmann Machine with Bernoulli visible and hidden units.

    Attributes:
        w: Weight matrix of shape (nv, nh), initialised from N(0, 0.01^2).
        a: Visible bias of shape (nv,), initialised to zero.
        b: Hidden bias of shape (nh,), initialised to zero.
    """

    def __init__(self, nv, nh):
        """Create the parameters for `nv` visible and `nh` hidden units."""
        super().__init__()

        self.w = torch.nn.Parameter(torch.randn(nv, nh) * 0.01)
        self.a = torch.nn.Parameter(torch.zeros(nv))
        self.b = torch.nn.Parameter(torch.zeros(nh))

    def sample_h(self, v):
        """Sample hidden units given visible units.

        Args:
            v: Float tensor whose last dimension is nv.

        Returns:
            Tuple `(h, phv)`: a Bernoulli sample of the hidden units and the
            activation probabilities sigmoid(v @ w + b) it was drawn from.
        """
        phv = torch.sigmoid(torch.matmul(v, self.w) + self.b)
        h = torch.bernoulli(phv)

        return h, phv

    def sample_v(self, h):
        """Sample visible units given hidden units (or hidden probabilities).

        Args:
            h: Float tensor whose last dimension is nh.

        Returns:
            Tuple `(v, pvh)`: a Bernoulli sample of the visible units and the
            activation probabilities sigmoid(h @ w.T + a) it was drawn from.
        """
        pvh = torch.sigmoid(torch.matmul(h, self.w.t()) + self.a)
        v = torch.bernoulli(pvh)

        return v, pvh

    def forward(self, v, n_steps=None):
        """Run a Gibbs chain starting from `v` and return the final visible sample.

        Args:
            v: Batch of visible configurations to start the chain from.
            n_steps: Number of Gibbs iterations. When omitted, the
                module-level global `k` is used, as in the original code.

        Returns:
            The visible sample drawn after the last Gibbs iteration.
        """
        steps = k if n_steps is None else n_steps

        _, phv = self.sample_h(v)

        # Gibbs sampling: the visible units are driven by the hidden
        # *probabilities* rather than by the binary hidden samples.
        for _ in range(steps):
            v, _ = self.sample_v(phv)
            _, phv = self.sample_h(v)
        v, _ = self.sample_v(phv)

        return v

    def free_energy(self, v):
        """Return the free energy of each visible configuration.

        Computes F(v) = -(v . a + sum_j softplus((v @ w + b)_j)).

        Args:
            v: Float tensor of shape (batch, nv).

        Returns:
            Tensor of shape (batch,) holding the free energy of every row.
        """
        vt = torch.matmul(v, self.a)
        # softplus(x) == log(1 + exp(x)), but does not overflow to inf when
        # the pre-activations become large.
        pre_activation = torch.matmul(v, self.w) + self.b
        ht = torch.sum(torch.nn.functional.softplus(pre_activation), dim=1)

        return -(vt + ht)
```

Next we define the main training loop. Note that each tensor is moved to the proper device before it is used.

```python
def train(rbm, train_loader, learning_rate, k, training_epochs):
    """Train `rbm` with contrastive divergence using plain SGD.

    For every batch, a model sample is drawn with `rbm.forward` and the
    parameters are updated to lower the free energy of the data relative to
    the free energy of that sample.

    Args:
        rbm: Model exposing `forward(v)`, `free_energy(v)`, `parameters()`
            and a visible-bias parameter `a`.
        train_loader: Iterable of batches. Each batch is a tensor, or a
            list/tuple whose first element is the tensor of samples.
        learning_rate: Step size of the SGD optimizer.
        k: Kept for interface compatibility; not used here, because the
            number of Gibbs steps is decided inside `rbm.forward`.
        training_epochs: Number of passes over `train_loader`.

    Returns:
        List with the mean batch cost of each epoch (0.0 for an epoch in
        which the loader yielded no batches).
    """
    optimizer = torch.optim.SGD(rbm.parameters(), lr=learning_rate)

    # Take device, dtype and input width from the model itself instead of
    # relying on module-level globals.
    device = rbm.a.device
    dtype = rbm.a.dtype
    n_visible = rbm.a.shape[0]

    history = []
    for _ in range(training_epochs):
        epoch_cost = 0.
        n_batches = 0

        for batch in train_loader:
            # Loaders built on tuple-style datasets yield (inputs, ...).
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            batch = batch.reshape(-1, n_visible).to(device=device, dtype=dtype)

            # The model sample is treated as a constant in the gradient.
            v = rbm.forward(batch).detach()

            cost = torch.mean(rbm.free_energy(batch)) - torch.mean(rbm.free_energy(v))

            optimizer.zero_grad()
            cost.backward()
            optimizer.step()

            epoch_cost += cost.item()
            n_batches += 1

        history.append(epoch_cost / n_batches if n_batches else 0.0)

    return history
``````