diff options
Diffstat (limited to 'gnu')
-rw-r--r-- | gnu/packages/machine-learning.scm | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm index 6fc693a47b..fae6d244b0 100644 --- a/gnu/packages/machine-learning.scm +++ b/gnu/packages/machine-learning.scm @@ -102,6 +102,7 @@ #:use-module (gnu packages statistics) #:use-module (gnu packages sqlite) #:use-module (gnu packages swig) + #:use-module (gnu packages time) #:use-module (gnu packages tls) #:use-module (gnu packages video) #:use-module (gnu packages web) @@ -3518,6 +3519,128 @@ validating answers, managing hierarchical prompts and providing error feedback.") (license license:expat))) +(define-public python-pytorch-lightning + (package + (name "python-pytorch-lightning") + (version "2.0.2") + (source (origin + (method git-fetch) + (uri (git-reference + (url "https://github.com/Lightning-AI/lightning") + (commit version))) + (file-name (git-file-name name version)) + (sha256 + (base32 + "1w4lajiql4y5nnhqf6i5wii1mrwnhp5f4bzbwdzb5zz0d0lysb1i")))) + (build-system pyproject-build-system) + (arguments + (list + #:test-flags + '(list "-m" "not cloud and not tpu" "tests/tests_pytorch" + ;; we don't have onnxruntime + "--ignore=tests/tests_pytorch/models/test_onnx.py" + + ;; We don't have tensorboard, so we skip all those tests that + ;; require it for logging. + "--ignore=tests/tests_pytorch/checkpointing/test_model_checkpoint.py" + "--ignore=tests/tests_pytorch/loggers/test_all.py" + "--ignore=tests/tests_pytorch/loggers/test_logger.py" + "--ignore=tests/tests_pytorch/loggers/test_tensorboard.py" + "--ignore=tests/tests_pytorch/models/test_cpu.py" + "--ignore=tests/tests_pytorch/models/test_hparams.py" + "--ignore=tests/tests_pytorch/models/test_restore.py" + "--ignore=tests/tests_pytorch/profilers/test_profiler.py" + "--ignore=tests/tests_pytorch/trainer/flags/test_fast_dev_run.py" + "--ignore=tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py" + "--ignore=tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py" + "--ignore=tests/tests_pytorch/trainer/properties/test_loggers.py" + "--ignore=tests/tests_pytorch/trainer/properties/test_log_dir.py" + "--ignore=tests/tests_pytorch/trainer/test_trainer.py" + + ;; This needs internet access + "--ignore=tests/tests_pytorch/helpers/test_models.py" + "--ignore=tests/tests_pytorch/helpers/test_datasets.py" + "--ignore=tests/tests_pytorch/helpers/datasets.py" + + ;; We have no legacy checkpoints + "--ignore=tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py" + + ;; TypeError: _FlakyPlugin._make_test_flaky() got an unexpected keyword argument 'reruns' + "--ignore=tests/tests_pytorch/models/test_amp.py" + "--ignore=tests/tests_pytorch/profilers/test_profiler.py" + + "--ignore=tests/tests_pytorch/graveyard/test_legacy_import_unpickler.py" + + "-k" + (string-append + ;; We don't have tensorboard + "not test_property_logger" + " and not test_cli_logger_shorthand" + ;; Something wrong with Flaky + " and not test_servable_module_validator_with_trainer")) + #:phases + '(modify-phases %standard-phases + (add-after 'unpack 'patch-version-detection + (lambda _ + ;; We do have pytorch 1.13.1, but the version comparison fails. + (substitute* "src/lightning/fabric/utilities/imports.py" + (("_TORCH_GREATER_EQUAL_1_13 =.*") + "_TORCH_GREATER_EQUAL_1_13 = True\n")))) + (add-before 'build 'pre-build + (lambda _ (setenv "PACKAGE_NAME" "lightning"))) + (add-after 'install 'pre-build-pytorch + (lambda _ + ;; pyproject-build-system only tolerates unicycles. + (for-each delete-file (find-files "dist" "\\.whl")) + (setenv "PACKAGE_NAME" "pytorch"))) + (add-after 'pre-build-pytorch 'build-pytorch + (assoc-ref %standard-phases 'build)) + (add-after 'build-pytorch 'install-pytorch + (assoc-ref %standard-phases 'install)) + (add-before 'check 'pre-check + (lambda _ + ;; We don't have Tensorboard + (substitute* "tests/tests_pytorch/test_cli.py" + ((" TensorBoardLogger\\(\".\"\\)") ""))))))) + (propagated-inputs + (list python-arrow + python-beautifulsoup4 + python-croniter + python-dateutils + python-deepdiff + python-fastapi-for-pytorch-lightning + python-fsspec + python-inquirer + python-jsonargparse + python-lightning-cloud + python-lightning-utilities + python-numpy + python-packaging + python-pytorch + python-pyyaml + python-starsessions-for-pytorch-lightning + python-torchmetrics + python-torchvision + python-tqdm + python-traitlets + python-typing-extensions)) + (native-inputs + (list python-aiohttp + python-cloudpickle + python-coverage + python-flaky + python-pympler + python-pytest + python-psutil + python-requests-mock + python-scikit-learn)) + (home-page "https://lightning.ai/") + (synopsis "Deep learning framework to train, deploy, and ship AI products") + (description + "PyTorch Lightning is just organized PyTorch; Lightning disentangles +PyTorch code to decouple the science from the engineering.") + (license license:asl2.0))) + (define-public python-torchmetrics (package (name "python-torchmetrics") |