
naturomics/CapsNet-Tensorflow

A TensorFlow implementation of CapsNet (Capsule Network) from the paper Dynamic Routing Between Capsules


Top Related Projects

A Keras implementation of CapsNet in NIPS2017 paper "Dynamic Routing Between Capsules". Now test error = 0.34%.

A PyTorch implementation of the NIPS 2017 paper "Dynamic Routing Between Capsules".


Models and examples built with TensorFlow

Quick Overview

CapsNet-Tensorflow is an implementation of Capsule Networks (CapsNet) using TensorFlow. It aims to reproduce the results of the original paper "Dynamic Routing Between Capsules" by Sabour et al. The project provides a flexible and extensible codebase for experimenting with CapsNet architectures on various datasets.

Pros

  • Implements the novel CapsNet architecture, which addresses limitations of traditional CNNs
  • Provides a modular and extensible codebase for easy experimentation
  • Includes trained models and supports the MNIST and Fashion-MNIST datasets
  • Well-documented with clear instructions for setup and usage

Cons

  • May require significant computational resources for training on larger datasets
  • Limited to TensorFlow framework, which may not suit all users' preferences
  • Lacks some of the latest optimizations and improvements in CapsNet research
  • May require updates to work with the most recent versions of TensorFlow

Code Examples

  1. Loading and preprocessing data:

    from config import cfg
    from utils import load_mnist

    trainX, trainY, testX, testY = load_mnist(cfg.dataset, cfg.is_training)

  2. Creating the CapsNet model:

    from capsNet import CapsNet

    model = CapsNet(input_shape=[28, 28, 1],
                    n_class=10,
                    routing_iterations=3)

  3. Training the model:

    model.train(trainX, trainY, testX, testY)

  4. Evaluating the model:

    test_acc = model.evaluate(testX, testY)
    print(f"Test accuracy: {test_acc}")

Getting Started

  1. Clone the repository:

    git clone https://github.com/naturomics/CapsNet-Tensorflow.git
    cd CapsNet-Tensorflow
    
  2. Install dependencies:

    pip install -r requirements.txt
    
  3. Run the training script:

    python main.py
    
  4. To use a specific configuration, modify the config.py file or pass arguments:

    python main.py --dataset mnist --batch_size 128 --epoch 50
    

Competitor Comparisons

A Keras implementation of CapsNet in NIPS2017 paper "Dynamic Routing Between Capsules". Now test error = 0.34%.

Pros of CapsNet-Keras

  • More user-friendly and easier to understand for those familiar with Keras
  • Better documentation and code organization
  • Supports both TensorFlow and Theano backends

Cons of CapsNet-Keras

  • Slightly slower training and inference times
  • Less flexibility in terms of customization compared to the TensorFlow implementation

Code Comparison

CapsNet-Keras:

def squash(vectors, axis=-1):
    s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
    return scale * vectors

CapsNet-Tensorflow:

def squash(vector):
    vec_squared_norm = tf.reduce_sum(tf.square(vector), -2, keepdims=True)
    scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + epsilon)
    vec_squashed = scalar_factor * vector
    return vec_squashed

The code comparison shows that both implementations use similar logic for the squash function, but CapsNet-Keras uses Keras backend functions (K) while CapsNet-Tensorflow uses TensorFlow operations directly. This difference reflects the overall approach of each repository, with CapsNet-Keras providing a higher-level abstraction through the Keras API.
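As a quick sanity check on that shared logic, here is a small NumPy transcription (illustrative, not taken from either repository): squashing preserves a vector's direction while mapping its norm to ‖v‖² / (1 + ‖v‖²), so short vectors are suppressed toward zero and long vectors approach unit length.

import numpy as np

def squash(v, eps=1e-9):
    # Same formula as both snippets above, applied over the last axis.
    sq_norm = np.sum(np.square(v), axis=-1, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * v / np.sqrt(sq_norm + eps)

print(np.linalg.norm(squash(np.array([0.1, 0.0]))))   # ~0.0099: short vectors are suppressed
print(np.linalg.norm(squash(np.array([10.0, 0.0]))))  # ~0.9901: long vectors approach norm 1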

A PyTorch implementation of the NIPS 2017 paper "Dynamic Routing Between Capsules".

Pros of capsule-networks

  • More recent implementation, potentially incorporating newer insights
  • Cleaner, more modular code structure
  • Better documentation and explanations of concepts

Cons of capsule-networks

  • Less comprehensive, focusing mainly on MNIST dataset
  • Fewer options for customization and experimentation
  • Less active development and community engagement

Code Comparison

CapsNet-Tensorflow:

def squash(vector):
    vec_squared_norm = tf.reduce_sum(tf.square(vector), -2, keepdims=True)
    scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + epsilon)
    vec_squashed = scalar_factor * vector
    return vec_squashed

capsule-networks:

def squash(s, axis=-1, epsilon=1e-7):
    squared_norm = K.sum(K.square(s), axis=axis, keepdims=True)
    safe_norm = K.sqrt(squared_norm + epsilon)
    squash_factor = squared_norm / (1. + squared_norm)
    unit_vector = s / safe_norm
    return squash_factor * unit_vector

Both implementations express the same squash formula; the capsule-networks version computes an explicit safe norm and unit vector before scaling, which makes the numerical-stability handling easier to follow.

2,912

Models and examples built with TensorFlow

Pros of models

  • Broader scope, covering multiple machine learning models and applications
  • More comprehensive documentation and examples
  • Active development and regular updates

Cons of models

  • Less focused on CapsNet specifically
  • Potentially more complex to navigate for those solely interested in CapsNet

Code Comparison

CapsNet-Tensorflow:

def squash(vector):
    vector_squared_norm = tf.reduce_sum(tf.square(vector), -2, keepdims=True)
    scalar_factor = vector_squared_norm / (1 + vector_squared_norm) / tf.sqrt(vector_squared_norm + epsilon)
    return scalar_factor * vector

models:

def squash(s, axis=-1, epsilon=1e-7, name=None):
    with tf.name_scope(name, default_name="squash"):
        squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keepdims=True)
        safe_norm = tf.sqrt(squared_norm + epsilon)
        squash_factor = squared_norm / (1. + squared_norm)
        unit_vector = s / safe_norm
        return squash_factor * unit_vector

Summary

While CapsNet-Tensorflow focuses specifically on implementing Capsule Networks, models offers a broader range of machine learning models and applications. models provides more comprehensive documentation and examples, making it potentially more accessible for beginners. However, for those specifically interested in CapsNet, CapsNet-Tensorflow may offer a more focused and streamlined experience. The code comparison shows similar implementations of the squash function, with models offering slightly more detailed naming and organization.


README

CapsNet-Tensorflow


A TensorFlow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules

(Figure: capsule vs. traditional neuron comparison)

Notes:

  1. The current version supports the MNIST and Fashion-MNIST datasets. Current test accuracy is 99.64% on MNIST and 90.60% on Fashion-MNIST; see details in the Results section
  2. See dist_version for multi-GPU support
  3. Here (on Zhihu, 知乎) is an article explaining my understanding of the paper. It may be helpful in understanding the code.

Important:

If you need to apply the CapsNet model to your own datasets, or to build a new model from CapsNet's basic blocks, please follow my new project CapsLayer, an advanced library for capsule theory that aims to integrate capsule-relevant technologies, provide analysis tools, develop application examples, and promote the development of capsule theory. For example, you can easily use capsule layer blocks in your code through the APIs capsLayer.layers.fully_connected and capsLayer.layers.conv2d, as sketched below.
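For illustration, here is a hypothetical sketch of what wiring in a CapsLayer block might look like. Only the entry-point names capsLayer.layers.fully_connected and capsLayer.layers.conv2d come from this README; every argument name below is an assumption, so check the CapsLayer documentation for the actual signatures.

# Hypothetical usage sketch: all keyword arguments are assumptions,
# not CapsLayer's documented API.
import capslayer

def digit_caps(primary_caps):
    # Route primary capsules into 10 output capsules of dimension 16,
    # mirroring the DigitCaps layer of the paper.
    return capslayer.layers.fully_connected(primary_caps,
                                            num_outputs=10,
                                            out_caps_shape=[16, 1])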

Requirements

  • Python
  • NumPy
  • Tensorflow>=1.3
  • tqdm (for displaying training progress info)
  • scipy (for saving images)

Usage

Step 1. Download this repository with git or click the download ZIP button.

$ git clone https://github.com/naturomics/CapsNet-Tensorflow.git
$ cd CapsNet-Tensorflow

Step 2. Download MNIST or Fashion-MNIST dataset. In this step, you have two choices:

  • a) Automatic downloading with the download_data.py script:
$ python download_data.py   # for the MNIST dataset
$ python download_data.py --dataset fashion-mnist --save_to data/fashion-mnist   # for Fashion-MNIST
  • b) Manual downloading with wget or other tools; move and extract the dataset into the data/mnist or data/fashion-mnist directory, for example:
$ mkdir -p data/mnist
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
$ gunzip data/mnist/*.gz

Step 3. Start the training (the MNIST dataset is used by default):

$ python main.py
$ # or training for fashion-mnist dataset
$ python main.py --dataset fashion-mnist
$ # If you need to monitor the training process, open tensorboard with this command
$ tensorboard --logdir=logdir
$ # or use the `tail` command on a Linux system
$ tail -f results/val_acc.csv

Step 4. Calculate test accuracy

$ python main.py --is_training=False
$ # for fashion-mnist dataset
$ python main.py --dataset fashion-mnist --is_training=False

Note: The default batch size is 128 and the default number of epochs is 50. You may need to modify the config.py file or use command-line parameters to suit your case, e.g. to set the batch size to 64 and run a test summary once every 200 steps: python main.py --test_sum_freq=200 --batch_size=64

Results

The pictures here were plotted with TensorBoard and my tool plot_acc.R.

  • Training loss (plots of total_loss, margin_loss, and reconstruction_loss in the original README)

Here are the models I trained, along with my talk and some other materials:

Baidu Netdisk (password: ahjs)

  • The best val error (using reconstruction):

| Routing iterations | 1 | 3 | 4 |
| ------------------ | ---- | ---- | ---- |
| val error | 0.36 | 0.36 | 0.41 |
| Paper | 0.29 | 0.25 | - |

(Test accuracy curve: test_acc plot in the original README)

My simple comments on capsules

  1. A new kind of neural unit (vector in, vector out, rather than scalar in, scalar out)
  2. The routing algorithm is similar to an attention mechanism (see the sketch below)
  3. In any case, a work with great potential and a lot to build on
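To make point 2 concrete, here is a minimal NumPy sketch of routing-by-agreement, written for this comparison rather than taken from the repository; the shapes and variable names are illustrative assumptions. The softmax over the coupling logits plays the role that attention weights play, and the agreement between predictions and outputs updates those logits.

import numpy as np

def squash(s, eps=1e-9):
    # Same squash formula discussed above, over the last axis.
    sq = np.sum(np.square(s), axis=-1, keepdims=True)
    return (sq / (1.0 + sq)) * s / np.sqrt(sq + eps)

def dynamic_routing(u_hat, num_iters=3):
    # u_hat: prediction vectors, shape (num_in_caps, num_out_caps, out_dim)
    b = np.zeros(u_hat.shape[:2])                  # coupling logits, start at zero
    for _ in range(num_iters):
        c = np.exp(b) / np.exp(b).sum(axis=1, keepdims=True)  # softmax: attention-like weights
        s = (c[..., None] * u_hat).sum(axis=0)     # weighted sum of predictions
        v = squash(s)                              # candidate output capsules
        b += (u_hat * v[None]).sum(axis=-1)        # agreement <u_hat, v> updates the logits
    return v

v = dynamic_routing(np.random.randn(1152, 10, 16))  # e.g. PrimaryCaps -> DigitCaps shapes
print(v.shape)  # (10, 16)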

My WeChat: (QR code image in the original README)

Reference

  • Sara Sabour, Nicholas Frosst, Geoffrey E. Hinton. "Dynamic Routing Between Capsules." NIPS 2017. https://arxiv.org/abs/1710.09829