Source code for hopfield

import numpy as np
seed=42
np.random.seed(seed)
from matplotlib import pyplot as plt
import matplotlib.cm as cm

[docs]class HopfieldNetwork(object): """Implementation of a Hopfield Network using Hebb rule and synchronous pattern recovery. """
[docs] def train(self, train_data): """Train the network using Hebbian learning rule. :param train_data: list of preprocessed training images :type train_data: list of np.ndarray :return: None """ num_data = len(train_data) self.num_neurons = train_data[0].shape[0] W = np.zeros((self.num_neurons, self.num_neurons)) rho = np.sum([np.sum(t) for t in train_data]) / (num_data*self.num_neurons) for i in range(num_data): t = train_data[i] - rho W += np.outer(t, t) diagW = np.diag(np.diag(W)) W = W - diagW W /= num_data self.W = W
[docs] def predict(self, data, num_iter=20, threshold=0): """Recover stored patterns from noisy images. :param data: list of corrupted samples to be reconstructed :type data: list of np.ndarray :param num_iter: number of iterations to run :type num_iter: int, defaults to 20 :param threshold: activation threshold for neurons :type threshold: float, defaults to 0 :return: list of predictions :rtype: list of np.ndarray """ self.num_iter = num_iter self.threshold = threshold copied_data = np.copy(data) preds = [] for i in range(len(data)): preds.append(self.sync_update(copied_data[i])) return preds
[docs] def sync_update(self, init_s): """Synchronous update :param init_s: initial state; the corrupted image :type init_s: np.ndarray :return: predicted state :rtype: np.ndarray """ s = init_s e = self.compute_energy(s) for i in range(self.num_iter): s = np.sign(self.W @ s - self.threshold) e_new = self.compute_energy(s) if e == e_new: return s e = e_new return s
[docs] def compute_energy(self, s): """Compute energy of given state :param s: state for which energy needs to be computed :type s: np.ndarray :return: energy of state s :rtype: float """ return -0.5 * s @ self.W @ s + np.sum(s * self.threshold)
[docs] def plot_weight_matrix(self, figsize=(5,5)): """Plot weights of trained network :param figsize: figsize for matplotlib figure, defaults to (5,5) :type figsize: tuple of ints :return: None """ plt.figure(figsize=figsize) w_mat = plt.imshow(self.W, cmap=cm.coolwarm) plt.colorbar(w_mat, fraction=0.046, pad=0.04) plt.title("Weights Matrix for Trained Network") plt.tight_layout() plt.show()