This is the official repository for the paper "Flora: Low-Rank Adapters Are Secretly Gradient Compressors" in ICML 2024.
Implementation of the PSGD optimizer in JAX
JAX implementations of various deep reinforcement learning algorithms.
Tensor Networks for Machine Learning
The jQMC code implements two real-space ab initio quantum Monte Carlo (QMC) methods: variational Monte Carlo (VMC) and lattice-regularized diffusion Monte Carlo (LRDMC). jQMC achieves high performance, especially on GPUs.
Goal-conditioned reinforcement learning like 🔥
JAX/Flax implementation of finite-size scaling
Training methodologies for autoregressive neural operators/emulators in JAX.
EEG task classification with CNN, LSTM, CNN-LSTM, and GAN augmentation across TensorFlow, PyTorch, and JAX.
An Optax-based JAX implementation of the IVON optimizer for large-scale VI training of NNs (ICML'24 spotlight)
JAX implementation of Classical and Quantum Algorithms for Orthogonal Neural Networks by Kerenidis et al. (2021)
dLLM training implementation in pure JAX/Flax (without PyTorch) for Google TPUs (v4/v5e/v6e). #TPUSprint #TRC
H-Former is a VAE for generating in-between fonts (or combining fonts). Its encoder uses a PointNet and a transformer to compute a code vector for a glyph. Its decoder consists of multiple independent decoders that act on the code vector to reconstruct a point cloud representing the glyph.
Variational Graph Autoencoder implemented using JAX and Jraph
dm-haiku implementation of hyperbolic neural networks
A simplistic trainer for Flax
An implementation of the Adan optimizer for Optax