diff --git a/arrayfire_wrapper/lib/create_and_modify_array/assignment_and_indexing/_indexing.py b/arrayfire_wrapper/lib/create_and_modify_array/assignment_and_indexing/_indexing.py index 36b20c3..c5f42c1 100644 --- a/arrayfire_wrapper/lib/create_and_modify_array/assignment_and_indexing/_indexing.py +++ b/arrayfire_wrapper/lib/create_and_modify_array/assignment_and_indexing/_indexing.py @@ -5,8 +5,8 @@ from typing import Any from arrayfire_wrapper.lib._broadcast import bcast_var -from arrayfire_wrapper.lib.create_and_modify_array.manage_array import release_array - +from arrayfire_wrapper.lib.create_and_modify_array.manage_array import release_array, retain_array +from arrayfire_wrapper.defines import AFArray class _IndexSequence(ctypes.Structure): """ @@ -186,7 +186,7 @@ class IndexStructure(ctypes.Structure): ----------- idx: key - - If of type af.Array, self.idx.arr = idx, self.isSeq = False + - If of type AFArray, self.idx.arr = idx, self.isSeq = False - If of type af.ParallelRange, self.idx.seq = idx, self.isBatch = True - Default:, self.idx.seq = af._IndexSequence(idx) @@ -197,26 +197,21 @@ class IndexStructure(ctypes.Structure): """ - def __init__(self, idx: Any) -> None: + def __init__(self, idx: int | slice | AFArray) -> None: self.idx = _IndexUnion() self.isBatch = False self.isSeq = True - # BUG cyclic reimport - # if isinstance(idx, Array): - # if idx.dtype == af_bool: - # self.idx.arr = everything.where(idx.arr) - # else: - # self.idx.arr = everything.retain_array(idx.arr) - - # self.isSeq = False - - if isinstance(idx, ParallelRange): + if isinstance(idx, int) or isinstance(idx, slice): + self.idx.seq = _IndexSequence(idx) + elif isinstance(idx, ParallelRange): self.idx.seq = idx self.isBatch = True - + elif isinstance(idx, AFArray): + self.idx.arr = retain_array(idx) + self.isSeq = False else: - self.idx.seq = _IndexSequence(idx) + raise IndexError("Invalid type while indexing arrayfire.array") def __del__(self) -> None: if not self.isSeq: @@ -247,7 +242,7 @@ def __setitem__(self, idx: int, value: IndexStructure) -> None: self.idxs[idx] = value -def get_indices(key: int | slice | tuple[int | slice, ...]) -> CIndexStructure: # BUG +def get_indices(key: int | slice | tuple[int | slice | AFArray, ...] | AFArray) -> CIndexStructure: # BUG indices = CIndexStructure() if isinstance(key, tuple): pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy