LuxDL/Lux.jl

Elegant & Performant Scientific Machine Learning in Julia

A Pure Julia Deep Learning Framework designed for Scientific Machine Learning

💻 Installation

import Pkg
Pkg.add("Lux")
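
GPU support is optional and is provided by companion packages (see the package list below). For example, to target NVIDIA GPUs, a sketch (AMDGPU.jl, Metal.jl, and oneAPI.jl cover the other backends):

import Pkg
Pkg.add(["Lux", "LuxCUDA"])  # LuxCUDA.jl enables CUDA support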

Tip

To try Lux online, use Google Colab; its Julia runtime comes with Lux and Reactant pre-installed!

Packages

📦 Lux.jl
└ 📦 LuxLib.jl
└ 📦 LuxCore.jl
└ 📦 MLDataDevices.jl
└ 📦 WeightInitializers.jl
└ 📦 LuxTestUtils.jl
└ 📦 LuxCUDA.jl
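
Several of the sub-packages are usable on their own. A minimal sketch of two of them (the array sizes here are illustrative):

using MLDataDevices, Random, WeightInitializers

rng = Random.default_rng()

# WeightInitializers provides the initializers that Lux layers use internally
W = kaiming_normal(rng, Float32, 256, 128)

# MLDataDevices abstracts over compute devices; cpu_device() always works, and
# gpu_device() selects a GPU backend when one of the GPU packages is loaded
dev = cpu_device()
W = W |> dev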

🤸 Quickstart

Reactant & Enzyme

using Lux, Random, Optimisers, Reactant, Enzyme

rng = Random.default_rng()
Random.seed!(rng, 0)

model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))

dev = reactant_device()

ps, st = Lux.setup(rng, model) |> dev

x = rand(rng, Float32, 128, 2) |> dev

# We need to compile the model before we can use it.
model_forward = @compile model(x, ps, Lux.testmode(st))
model_forward(x, ps, Lux.testmode(st))

# Gradients can be computed using Enzyme
@jit Enzyme.gradient(Reverse, sum ∘ first ∘ Lux.apply, Const(model), x, ps, Const(st))

# All of this can be automated using the TrainState API
train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

gs, loss, stats, train_state = Training.single_train_step!(
    AutoEnzyme(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state
)

Native Julia & Zygote

using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support

# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)

# Construct the layer
model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))

# Get the device determined by Lux
dev = gpu_device()

# Parameter and State Variables
ps, st = Lux.setup(rng, model) |> dev

# Dummy Input
x = rand(rng, Float32, 128, 2) |> dev

# Run the model
y, st = Lux.apply(model, x, ps, st)

# Gradients
## First construct a TrainState
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))

## We can compute the gradients using Training.compute_gradients
gs, loss, stats, train_state = Lux.Training.compute_gradients(AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state)

## Optimization
train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)

# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state)
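
Putting the pieces together, a training loop might look like the sketch below, which continues the snippet above (train_model, the synthetic data, and the epoch count are illustrative, not part of the Lux API):

function train_model(train_state, x_data, y_data; epochs = 100)
    for epoch in 1:epochs
        _, loss, _, train_state = Training.single_train_step!(
            AutoZygote(), MSELoss(), (x_data, y_data), train_state)
        epoch % 10 == 0 && println("Epoch $(epoch): loss = $(loss)")
    end
    return train_state
end

# Synthetic data matching the model's 128-dimensional input and 10-dimensional output
train_state = train_model(train_state,
    dev(rand(rng, Float32, 128, 64)), dev(rand(rng, Float32, 10, 64)))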

📚 Examples

Look in the examples directory for self-contained usage examples. The documentation has examples sorted into proper categories.

🆘 Getting Help

For usage-related questions, please use GitHub Discussions, which allows questions and answers to be indexed. To report bugs, use GitHub Issues or, even better, send in a pull request.

πŸ§‘β€πŸ”¬ Citation

If you found this library to be useful in academic work, then please cite:

@software{pal2023lux,
  author    = {Pal, Avik},
  title     = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}},
  month     = apr,
  year      = 2023,
  note      = {If you use this software, please cite it as below.},
  publisher = {Zenodo},
  version   = {v1.4.2},
  doi       = {10.5281/zenodo.7808903},
  url       = {https://doi.org/10.5281/zenodo.7808903},
  swhid     = {swh:1:dir:1a304ec3243961314a1cc7c1481a31c4386c4a34;origin=https://doi.org/10.5281/zenodo.7808903;visit=swh:1:snp:e2bbe43b14bde47c4ddf7e637eb7fc7bd10db8c7;anchor=swh:1:rel:2c0c0ff927e7bfe8fc8bc43fd553ab392a6eb403;path=/}
}

@thesis{pal2023efficient,
  title     = {{On Efficient Training \& Inference of Neural Differential Equations}},
  author    = {Pal, Avik},
  year      = {2023},
  school    = {Massachusetts Institute of Technology}
}

Also consider starring our GitHub repo.

πŸ§‘β€πŸ’» Contributing

This section is somewhat incomplete. You can contribute by helping to finish it 😜.

💎 Formatting (JuliaFormatter)

Note

Pin JuliaFormatter to v1 until upstream issues with v2 are resolved.
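
One way to do the pinning (a sketch using Pkg's standard API):

import Pkg
Pkg.add(name = "JuliaFormatter", version = "1")  # install the v1 series
Pkg.pin("JuliaFormatter")                        # prevent an automatic upgrade to v2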

using JuliaFormatter
format(".")  # formats every Julia file under the current directory

🧪 Testing

Running the full Lux.jl test suite takes a long time; here's how to test only a portion of the code.

For each @testitem, there are corresponding tags. For example, the tests for SkipConnection are tagged :core_layers:

@testitem "SkipConnection" setup=[SharedTestSetup] tags=[:core_layers] begin
    ...
end

We can test the group to which SkipConnection belongs by testing core_layers. To do so, set the LUX_TEST_GROUP environment variable, or rename the tag to further narrow the test scope:

export LUX_TEST_GROUP="core_layers"
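
Equivalently, from a Julia REPL (a sketch; it assumes you start in the root of the Lux.jl repository, and the environment variable is inherited by the test subprocess):

import Pkg
Pkg.activate(".")                      # activate the Lux.jl project
ENV["LUX_TEST_GROUP"] = "core_layers"  # picked up by runtests.jl
Pkg.test()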

Or directly modify the default test tag in runtests.jl:

# const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))
const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "core_layers"))

But be sure to restore the default value "all" before submitting the code.

Furthermore, if you want to run a specific test based on the name of the testset, you can use TestEnv.jl as follows. Start by activating the Lux environment and then run the following:

using TestEnv; TestEnv.activate(); using ReTestItems;

# Assuming you are in the main directory of Lux
ReTestItems.runtests("tests/"; name = "NAME OF THE TEST")

For the SkipConnection tests that would be:

ReTestItems.runtests("tests/"; name = "SkipConnection")