
Commit fcc682b

Skylion007 authored and pytorchmergebot committed
[BE][Ez]: Fully type nn.utils.clip_grad (#154801)
Fully types clip_grad and exposes typing annotations that were hidden by a bad decorator.

Pull Request resolved: #154801
Approved by: https://github.com/jansel
1 parent ed6ae20 commit fcc682b

File tree

1 file changed: 7 additions, 4 deletions


torch/nn/utils/clip_grad.py

Lines changed: 7 additions & 4 deletions

@@ -4,8 +4,8 @@
 import types
 import typing
 import warnings
-from typing import cast, Optional, Union
-from typing_extensions import deprecated
+from typing import Callable, cast, Optional, TypeVar, Union
+from typing_extensions import deprecated, ParamSpec, TypeAlias

 import torch
 from torch import Tensor
@@ -19,13 +19,16 @@
 __all__: list[str] = []


-_tensor_or_tensors = Union[
+_tensor_or_tensors: TypeAlias = Union[  # noqa: PYI042
     torch.Tensor,
     typing.Iterable[torch.Tensor],  # noqa: UP006 - needed until XLA's patch is updated
 ]

+_P = ParamSpec("_P")
+_R = TypeVar("_R")

-def _no_grad(func):
+
+def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:
     """
     This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
     clip_grad_norm_ and clip_grad_value_ themselves.

Comments (0)