From 6416e77280e64723bad17b6dbec8483de09271c3 Mon Sep 17 00:00:00 2001 From: Ricardo Wurmus Date: Sat, 9 Jul 2022 10:20:07 +0200 Subject: gnu: Add python-pyro-ppl. * gnu/packages/machine-learning.scm (python-pyro-ppl): New variable. --- gnu/packages/machine-learning.scm | 75 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) (limited to 'gnu') diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm index c1bdb8b31d..b19af8a1d5 100644 --- a/gnu/packages/machine-learning.scm +++ b/gnu/packages/machine-learning.scm @@ -73,6 +73,7 @@ #:use-module (gnu packages image) #:use-module (gnu packages image-processing) #:use-module (gnu packages imagemagick) + #:use-module (gnu packages jupyter) #:use-module (gnu packages libffi) #:use-module (gnu packages linux) #:use-module (gnu packages llvm) @@ -3299,3 +3300,77 @@ and Numpy.") (synopsis "Generic API for dispatch to Pyro backends.") (description "This package provides a generic API for dispatch to Pyro backends.") (license license:asl2.0))) + +(define-public python-pyro-ppl + (package + (name "python-pyro-ppl") + (version "1.8.1") + ;; The sources on pypi don't include tests. + (source + (origin + (method git-fetch) + (uri (git-reference + (url "https://github.com/pyro-ppl/pyro") + (commit version))) + (file-name (git-file-name name version)) + (sha256 + (base32 "0ns20mr8qgjshzbplrfzaz1xhb9ldbgvrj2rzlsxvns2bi1ddyl5")))) + (build-system python-build-system) + (arguments + `(#:phases + (modify-phases %standard-phases + (replace 'check + (lambda* (#:key tests? #:allow-other-keys) + ;; This tests features that are only implemented when non-free + ;; software is available (Intel MKL or CUDA). + (for-each delete-file + (list "tests/distributions/test_spanning_tree.py" + "tests/infer/mcmc/test_mcmc_api.py")) + + ;; Four test_gamma_elbo tests fail with bad values for unknown + ;; reasons. + (delete-file "tests/distributions/test_rejector.py") + ;; This test fails sometimes. + (delete-file "tests/optim/test_optim.py") + (invoke "pytest" "-vv" "--stage=unit")))))) + (propagated-inputs + (list python-numpy + python-opt-einsum + python-pyro-api + python-pytorch + python-tqdm)) + (native-inputs + (list ninja + jupyter + python-black + python-flake8 + python-graphviz + python-isort + python-lap + python-matplotlib + python-mypy + python-nbformat + python-nbsphinx + python-nbstripout + python-nbval + python-pandas + python-pillow + python-pypandoc + python-pytest + python-pytest-cov + python-pytest-xdist + python-scikit-learn + python-scipy + python-seaborn + python-sphinx + python-sphinx-rtd-theme + python-torchvision + python-visdom + python-wget + python-yapf)) + (home-page "https://pyro.ai") + (synopsis "Python library for probabilistic modeling and inference") + (description + "This package provides a Python library for probabilistic modeling and +inference.") + (license license:asl2.0))) -- cgit v1.2.3