diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ae319c7..e4cf3e4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -19,5 +19,5 @@ again. All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult -[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +[GitHub Help](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) for more information on using pull requests. diff --git a/enn/colabs/enn_demo.ipynb b/enn/colabs/enn_demo.ipynb index b5db969..5acccde 100644 --- a/enn/colabs/enn_demo.ipynb +++ b/enn/colabs/enn_demo.ipynb @@ -104,25 +104,19 @@ "\n", "import warnings\n", "\n", - "warnings.filterwarnings('ignore')\n", - "\n", - "\n", + "warnings.filterwarnings('ignore')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "#@title Development imports\n", - "from typing import Callable, NamedTuple\n", - "\n", - "import numpy as np\n", - "import pandas as pd\n", - "import plotnine as gg\n", - "\n", "from acme.utils.loggers.terminal import TerminalLogger\n", "import dataclasses\n", - "import chex\n", - "import haiku as hk\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import optax\n", - "import tensorflow as tf\n", - "import tensorflow_datasets as tfds" + "import optax" ] }, { @@ -135,15 +129,11 @@ "outputs": [], "source": [ "#@title ENN imports\n", - "import enn\n", "from enn import losses\n", "from enn import networks\n", "from enn import supervised\n", - "from enn import base\n", "from enn import data_noise\n", - "from enn import utils\n", - "from enn.supervised import classification_data\n", - "from enn.supervised import regression_data\n" + "from enn.supervised import classification_data, regression_data" ] }, { @@ -173,6 +163,7 @@ " learning_rate: float = 1e-3\n", " noise_std: float = 0.1\n", "\n", + "\n", "FLAGS = Config()" ] }, @@ -202,7 +193,7 @@ "# Logger\n", "logger = TerminalLogger('supervised_regression')\n", "\n", - "# Create Ensemble ENN with a prior network \n", + "# Create Ensemble ENN with a prior network\n", "enn = networks.MLPEnsembleMatchedPrior(\n", " output_sizes=[50, 50, 1],\n", " dummy_input=next(dataset).x,\n", @@ -211,11 +202,11 @@ " seed=FLAGS.seed,\n", ")\n", "\n", - "# L2 loss on perturbed outputs \n", + "# L2 loss on perturbed outputs\n", "noise_fn = data_noise.GaussianTargetNoise(enn, FLAGS.noise_std, FLAGS.seed)\n", "single_loss = losses.add_data_noise(losses.L2Loss(), noise_fn)\n", "loss_fn = losses.average_single_index_loss(single_loss, FLAGS.num_index_samples)\n", - " \n", + "\n", "# Optimizer\n", "optimizer = optax.adam(FLAGS.learning_rate)\n", "\n", diff --git a/enn/colabs/epinet_demo.ipynb b/enn/colabs/epinet_demo.ipynb index ed1ff08..c47723d 100644 --- a/enn/colabs/epinet_demo.ipynb +++ b/enn/colabs/epinet_demo.ipynb @@ -78,26 +78,20 @@ "\n", "import warnings\n", "\n", - "warnings.filterwarnings('ignore')\n", - "\n", - "\n", + "warnings.filterwarnings('ignore')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "#@title Development imports\n", - "from typing import Callable, NamedTuple\n", - "\n", - "import numpy as np\n", "import pandas as pd\n", "import plotnine as gg\n", - "\n", - "from acme.utils.loggers.terminal import TerminalLogger\n", - "import dataclasses\n", - "import chex\n", - "import haiku as hk\n", "import jax\n", - "import jax.numpy as jnp\n", - "import optax\n", - "import dill\n", - "import tensorflow as tf\n", - "import tensorflow_datasets as tfds" + "import dill" ] }, { @@ -110,12 +104,8 @@ "outputs": [], "source": [ "#@title ENN imports\n", - "import enn\n", - "from enn import datasets\n", - "from enn.checkpoints import base as checkpoint_base\n", "from enn.networks.epinet import base as epinet_base\n", "from enn.checkpoints import utils\n", - "from enn.checkpoints import imagenet\n", "from enn.checkpoints import catalog\n", "from enn import metrics as enn_metrics" ]