diff --git a/docarray/documents/text.py b/docarray/documents/text.py index c6e6645f4e1..557bffa02e3 100644 --- a/docarray/documents/text.py +++ b/docarray/documents/text.py @@ -24,7 +24,7 @@ class TextDoc(BaseDoc): from docarray.documents import TextDoc # use it directly - txt_doc = TextDoc(url='http://www.jina.ai/') + txt_doc = TextDoc(url='https://www.gutenberg.org/files/1065/1065-0.txt') txt_doc.text = txt_doc.url.load() # model = MyEmbeddingModel() # txt_doc.embedding = model(txt_doc.text) @@ -51,7 +51,7 @@ class MyText(TextDoc): second_embedding: Optional[AnyEmbedding] - txt_doc = MyText(url='http://www.jina.ai/') + txt_doc = MyText(url='https://www.gutenberg.org/files/1065/1065-0.txt') txt_doc.text = txt_doc.url.load() # model = MyEmbeddingModel() # txt_doc.embedding = model(txt_doc.text) @@ -93,8 +93,8 @@ class MultiModalDoc(BaseDoc): ```python from docarray.documents import TextDoc - doc = TextDoc(text='This is the main text', url='exampleurl.com') - doc2 = TextDoc(text='This is the main text', url='exampleurl.com') + doc = TextDoc(text='This is the main text', url='exampleurl.com/file') + doc2 = TextDoc(text='This is the main text', url='exampleurl.com/file') doc == 'This is the main text' # True doc == doc2 # True diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py index 6d930aa53f3..28fd8005ad8 100644 --- a/docarray/typing/url/any_url.py +++ b/docarray/typing/url/any_url.py @@ -1,8 +1,9 @@ +import mimetypes import os import urllib import urllib.parse import urllib.request -from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, List, Optional, Type, TypeVar, Union import numpy as np from pydantic import AnyUrl as BaseAnyUrl @@ -20,6 +21,8 @@ T = TypeVar('T', bound='AnyUrl') +mimetypes.init([]) + @_register_proto(proto_type_name='any_url') class AnyUrl(BaseAnyUrl, AbstractType): @@ -27,6 +30,17 @@ class AnyUrl(BaseAnyUrl, AbstractType): False # turn off host requirement to allow passing of local paths as URL ) + @classmethod + def mime_type(cls) -> str: + """Returns the mime type associated with the class.""" + raise NotImplementedError + + @classmethod + def extra_extensions(cls) -> List[str]: + """Returns a list of allowed file extensions for the class + that are not covered by the mimetypes library.""" + raise NotImplementedError + def _to_node_protobuf(self) -> 'NodeProto': """Convert Document into a NodeProto protobuf message. This function should be called when the Document is nested into another Document that need to @@ -38,6 +52,48 @@ def _to_node_protobuf(self) -> 'NodeProto': return NodeProto(text=str(self), type=self._proto_type_name) + @staticmethod + def _get_url_extension(url: str) -> str: + """ + Extracts and returns the file extension from a given URL. + If no file extension is present, the function returns an empty string. + + + :param url: The URL to extract the file extension from. + :return: The file extension without the period, if one exists, + otherwise an empty string. + """ + + parsed_url = urllib.parse.urlparse(url) + ext = os.path.splitext(parsed_url.path)[1] + ext = ext[1:] if ext.startswith('.') else ext + return ext + + @classmethod + def is_extension_allowed(cls, value: Any) -> bool: + """ + Check if the file extension of the URL is allowed for this class. + First, it guesses the mime type of the file. If it fails to detect the + mime type, it then checks the extra file extensions. + Note: This method assumes that any URL without an extension is valid. + + :param value: The URL or file path. + :return: True if the extension is allowed, False otherwise + """ + if cls is AnyUrl: + return True + + url_parts = value.split('?') + extension = cls._get_url_extension(value) + if not extension: + return True + + mimetype, _ = mimetypes.guess_type(url_parts[0]) + if mimetype and mimetype.startswith(cls.mime_type()): + return True + + return extension in cls.extra_extensions() + @classmethod def validate( cls: Type[T], @@ -61,10 +117,12 @@ def validate( url = super().validate(abs_path, field, config) # basic url validation - if input_is_relative_path: - return cls(str(value), scheme=None) - else: - return cls(str(url), scheme=None) + if not cls.is_extension_allowed(value): + raise ValueError( + f"The file '{value}' is not in a valid format for class '{cls.__name__}'." + ) + + return cls(str(value if input_is_relative_path else url), scheme=None) @classmethod def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts': diff --git a/docarray/typing/url/audio_url.py b/docarray/typing/url/audio_url.py index a84a68754ee..bd71a68b824 100644 --- a/docarray/typing/url/audio_url.py +++ b/docarray/typing/url/audio_url.py @@ -1,10 +1,11 @@ import warnings -from typing import Optional, Tuple, TypeVar +from typing import List, Optional, Tuple, TypeVar from docarray.typing import AudioNdArray from docarray.typing.bytes.audio_bytes import AudioBytes from docarray.typing.proto_register import _register_proto from docarray.typing.url.any_url import AnyUrl +from docarray.typing.url.mimetypes import AUDIO_MIMETYPE from docarray.utils._internal.misc import is_notebook T = TypeVar('T', bound='AudioUrl') @@ -17,6 +18,18 @@ class AudioUrl(AnyUrl): Can be remote (web) URL, or a local file path. """ + @classmethod + def mime_type(cls) -> str: + return AUDIO_MIMETYPE + + @classmethod + def extra_extensions(cls) -> List[str]: + """ + Returns a list of additional file extensions that are valid for this class + but cannot be identified by the mimetypes library. + """ + return [] + def load(self: T) -> Tuple[AudioNdArray, int]: """ Load the data from the url into an [`AudioNdArray`][docarray.typing.AudioNdArray] diff --git a/docarray/typing/url/image_url.py b/docarray/typing/url/image_url.py index 43758cf7436..ffbeef15098 100644 --- a/docarray/typing/url/image_url.py +++ b/docarray/typing/url/image_url.py @@ -1,10 +1,11 @@ import warnings -from typing import TYPE_CHECKING, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar from docarray.typing import ImageBytes from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.image import ImageNdArray from docarray.typing.url.any_url import AnyUrl +from docarray.typing.url.mimetypes import IMAGE_MIMETYPE from docarray.utils._internal.misc import is_notebook if TYPE_CHECKING: @@ -20,6 +21,18 @@ class ImageUrl(AnyUrl): Can be remote (web) URL, or a local file path. """ + @classmethod + def mime_type(cls) -> str: + return IMAGE_MIMETYPE + + @classmethod + def extra_extensions(cls) -> List[str]: + """ + Returns a list of additional file extensions that are valid for this class + but cannot be identified by the mimetypes library. + """ + return [] + def load_pil(self, timeout: Optional[float] = None) -> 'PILImage.Image': """ Load the image from the bytes into a `PIL.Image.Image` instance diff --git a/docarray/typing/url/mimetypes.py b/docarray/typing/url/mimetypes.py new file mode 100644 index 00000000000..824f1c3150e --- /dev/null +++ b/docarray/typing/url/mimetypes.py @@ -0,0 +1,94 @@ +TEXT_MIMETYPE = 'text' +AUDIO_MIMETYPE = 'audio' +IMAGE_MIMETYPE = 'image' +OBJ_MIMETYPE = 'application/x-tgif' +VIDEO_MIMETYPE = 'video' + +MESH_EXTRA_EXTENSIONS = [ + '3ds', + '3mf', + 'ac', + 'ac3d', + 'amf', + 'assimp', + 'bvh', + 'cob', + 'collada', + 'ctm', + 'dxf', + 'e57', + 'fbx', + 'gltf', + 'glb', + 'ifc', + 'lwo', + 'lws', + 'lxo', + 'md2', + 'md3', + 'md5', + 'mdc', + 'm3d', + 'mdl', + 'ms3d', + 'nff', + 'obj', + 'off', + 'pcd', + 'pod', + 'pmd', + 'pmx', + 'ply', + 'q3o', + 'q3s', + 'raw', + 'sib', + 'smd', + 'stl', + 'ter', + 'terragen', + 'vtk', + 'vrml', + 'x3d', + 'xaml', + 'xgl', + 'xml', + 'xyz', + 'zgl', + 'vta', +] + +TEXT_EXTRA_EXTENSIONS = ['md', 'log'] + +POINT_CLOUD_EXTRA_EXTENSIONS = [ + 'ascii', + 'bin', + 'b3dm', + 'bpf', + 'dp', + 'dxf', + 'e57', + 'fls', + 'fls', + 'glb', + 'ply', + 'gpf', + 'las', + 'obj', + 'osgb', + 'pcap', + 'pcd', + 'pdal', + 'pfm', + 'ply', + 'ply2', + 'pod', + 'pods', + 'pnts', + 'ptg', + 'ptx', + 'pts', + 'rcp', + 'xyz', + 'zfs', +] diff --git a/docarray/typing/url/text_url.py b/docarray/typing/url/text_url.py index 86da87790e6..8e7f40cfda7 100644 --- a/docarray/typing/url/text_url.py +++ b/docarray/typing/url/text_url.py @@ -1,7 +1,8 @@ -from typing import Optional, TypeVar +from typing import List, Optional, TypeVar from docarray.typing.proto_register import _register_proto from docarray.typing.url.any_url import AnyUrl +from docarray.typing.url.mimetypes import TEXT_EXTRA_EXTENSIONS, TEXT_MIMETYPE T = TypeVar('T', bound='TextUrl') @@ -13,6 +14,18 @@ class TextUrl(AnyUrl): Can be remote (web) URL, or a local file path. """ + @classmethod + def mime_type(cls) -> str: + return TEXT_MIMETYPE + + @classmethod + def extra_extensions(cls) -> List[str]: + """ + Returns a list of additional file extensions that are valid for this class + but cannot be identified by the mimetypes library. + """ + return TEXT_EXTRA_EXTENSIONS + def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str: """ Load the text file into a string. diff --git a/docarray/typing/url/url_3d/mesh_url.py b/docarray/typing/url/url_3d/mesh_url.py index 70f32eb5581..84645e8ae42 100644 --- a/docarray/typing/url/url_3d/mesh_url.py +++ b/docarray/typing/url/url_3d/mesh_url.py @@ -1,10 +1,11 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar import numpy as np from pydantic import parse_obj_as from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.ndarray import NdArray +from docarray.typing.url.mimetypes import MESH_EXTRA_EXTENSIONS from docarray.typing.url.url_3d.url_3d import Url3D if TYPE_CHECKING: @@ -20,6 +21,14 @@ class Mesh3DUrl(Url3D): Can be remote (web) URL, or a local file path. """ + @classmethod + def extra_extensions(cls) -> List[str]: + """ + Returns a list of additional file extensions that are valid for this class + but cannot be identified by the mimetypes library. + """ + return MESH_EXTRA_EXTENSIONS + def load( self: T, skip_materials: bool = True, diff --git a/docarray/typing/url/url_3d/point_cloud_url.py b/docarray/typing/url/url_3d/point_cloud_url.py index efe6ce6ae0e..94bbf19b0cc 100644 --- a/docarray/typing/url/url_3d/point_cloud_url.py +++ b/docarray/typing/url/url_3d/point_cloud_url.py @@ -1,10 +1,11 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar import numpy as np from pydantic import parse_obj_as from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.ndarray import NdArray +from docarray.typing.url.mimetypes import POINT_CLOUD_EXTRA_EXTENSIONS from docarray.typing.url.url_3d.url_3d import Url3D if TYPE_CHECKING: @@ -21,6 +22,14 @@ class PointCloud3DUrl(Url3D): Can be remote (web) URL, or a local file path. """ + @classmethod + def extra_extensions(cls) -> List[str]: + """ + Returns a list of additional file extensions that are valid for this class + but cannot be identified by the mimetypes library. + """ + return POINT_CLOUD_EXTRA_EXTENSIONS + def load( self: T, samples: int, diff --git a/docarray/typing/url/url_3d/url_3d.py b/docarray/typing/url/url_3d/url_3d.py index c55c0f954e7..78120d144cd 100644 --- a/docarray/typing/url/url_3d/url_3d.py +++ b/docarray/typing/url/url_3d/url_3d.py @@ -3,6 +3,7 @@ from docarray.typing.proto_register import _register_proto from docarray.typing.url.any_url import AnyUrl +from docarray.typing.url.mimetypes import OBJ_MIMETYPE from docarray.utils._internal.misc import import_library if TYPE_CHECKING: @@ -18,6 +19,10 @@ class Url3D(AnyUrl, ABC): Can be remote (web) URL, or a local file path. """ + @classmethod + def mime_type(cls) -> str: + return OBJ_MIMETYPE + def _load_trimesh_instance( self: T, force: Optional[str] = None, diff --git a/docarray/typing/url/video_url.py b/docarray/typing/url/video_url.py index 5bd7b1be0b9..e4a623e53af 100644 --- a/docarray/typing/url/video_url.py +++ b/docarray/typing/url/video_url.py @@ -1,9 +1,10 @@ import warnings -from typing import Optional, TypeVar +from typing import List, Optional, TypeVar from docarray.typing.bytes.video_bytes import VideoBytes, VideoLoadResult from docarray.typing.proto_register import _register_proto from docarray.typing.url.any_url import AnyUrl +from docarray.typing.url.mimetypes import VIDEO_MIMETYPE from docarray.utils._internal.misc import is_notebook T = TypeVar('T', bound='VideoUrl') @@ -16,6 +17,18 @@ class VideoUrl(AnyUrl): Can be remote (web) URL, or a local file path. """ + @classmethod + def mime_type(cls) -> str: + return VIDEO_MIMETYPE + + @classmethod + def extra_extensions(cls) -> List[str]: + """ + Returns a list of additional file extensions that are valid for this class + but cannot be identified by the mimetypes library. + """ + return [] + def load(self: T, **kwargs) -> VideoLoadResult: """ Load the data from the url into a `NamedTuple` of diff --git a/tests/index/weaviate/test_index_get_del_weaviate.py b/tests/index/weaviate/test_index_get_del_weaviate.py index 8c1bd15636e..10ac0acd823 100644 --- a/tests/index/weaviate/test_index_get_del_weaviate.py +++ b/tests/index/weaviate/test_index_get_del_weaviate.py @@ -403,7 +403,7 @@ class MyMultiModalDoc(BaseDoc): def test_index_document_with_bytes(weaviate_client): - doc = ImageDoc(id="1", url="www.foo.com", bytes_=b"foo") + doc = ImageDoc(id="1", url="www.foo.com/file", bytes_=b"foo") index = WeaviateDocumentIndex[ImageDoc]() index.index([doc]) diff --git a/tests/integrations/typing/test_typing_proto.py b/tests/integrations/typing/test_typing_proto.py index ff16c2bc1e0..e6fabf0f7a2 100644 --- a/tests/integrations/typing/test_typing_proto.py +++ b/tests/integrations/typing/test_typing_proto.py @@ -73,7 +73,7 @@ class Mymmdoc(BaseDoc): embedding=np.zeros((100, 1)), any_url='http://jina.ai', image_url='http://jina.ai/bla.jpg', - text_url='http://jina.ai', + text_url='http://jina.ai/file.txt', mesh_url='http://jina.ai/mesh.obj', point_cloud_url='http://jina.ai/mesh.obj', ) diff --git a/tests/units/typing/url/test_any_url.py b/tests/units/typing/url/test_any_url.py index f8b55a3fdac..d6633f1fe8a 100644 --- a/tests/units/typing/url/test_any_url.py +++ b/tests/units/typing/url/test_any_url.py @@ -40,3 +40,20 @@ def test_operators(): assert url != 'aljdñjd' assert 'data' in url assert 'docarray' not in url + + +def test_get_url_extension(): + # Test with a URL with extension + assert AnyUrl._get_url_extension('https://jina.ai/hey.md?model=gpt-4') == 'md' + assert AnyUrl._get_url_extension('https://jina.ai/text.txt') == 'txt' + assert AnyUrl._get_url_extension('bla.jpg') == 'jpg' + + # Test with a URL without extension + assert not AnyUrl._get_url_extension('https://jina.ai') + assert not AnyUrl._get_url_extension('https://jina.ai/?model=gpt-4') + + # Test with a text without extension + assert not AnyUrl._get_url_extension('some_text') + + # Test with empty input + assert not AnyUrl._get_url_extension('') diff --git a/tests/units/typing/url/test_audio_url.py b/tests/units/typing/url/test_audio_url.py index 2e6b46bcabf..36b80e8d0b6 100644 --- a/tests/units/typing/url/test_audio_url.py +++ b/tests/units/typing/url/test_audio_url.py @@ -1,3 +1,4 @@ +import os from typing import Optional import numpy as np @@ -8,6 +9,13 @@ from docarray import BaseDoc from docarray.base_doc.io.json import orjson_dumps from docarray.typing import AudioBytes, AudioTorchTensor, AudioUrl +from docarray.typing.url.mimetypes import ( + OBJ_MIMETYPE, + AUDIO_MIMETYPE, + VIDEO_MIMETYPE, + IMAGE_MIMETYPE, + TEXT_MIMETYPE, +) from docarray.utils._internal.misc import is_tf_available from tests import TOYDATA_DIR @@ -123,3 +131,25 @@ def test_load_bytes(): assert isinstance(audio_bytes, bytes) assert isinstance(audio_bytes, AudioBytes) assert len(audio_bytes) > 0 + + +@pytest.mark.parametrize( + 'file_type, file_source', + [ + (AUDIO_MIMETYPE, AUDIO_FILES[0]), + (AUDIO_MIMETYPE, AUDIO_FILES[1]), + (AUDIO_MIMETYPE, REMOTE_AUDIO_FILE), + (IMAGE_MIMETYPE, os.path.join(TOYDATA_DIR, 'test.png')), + (VIDEO_MIMETYPE, os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'test' 'test.html')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'test' 'test.md')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'penal_colony.txt')), + (OBJ_MIMETYPE, os.path.join(TOYDATA_DIR, 'test.glb')), + ], +) +def test_file_validation(file_type, file_source): + if file_type != AudioUrl.mime_type(): + with pytest.raises(ValueError): + parse_obj_as(AudioUrl, file_source) + else: + parse_obj_as(AudioUrl, file_source) diff --git a/tests/units/typing/url/test_image_url.py b/tests/units/typing/url/test_image_url.py index 4054c997c80..bb9efe7cd36 100644 --- a/tests/units/typing/url/test_image_url.py +++ b/tests/units/typing/url/test_image_url.py @@ -9,6 +9,14 @@ from docarray.base_doc.io.json import orjson_dumps from docarray.typing import ImageUrl +from docarray.typing.url.mimetypes import ( + OBJ_MIMETYPE, + AUDIO_MIMETYPE, + VIDEO_MIMETYPE, + IMAGE_MIMETYPE, + TEXT_MIMETYPE, +) +from tests import TOYDATA_DIR CUR_DIR = os.path.dirname(os.path.abspath(__file__)) PATH_TO_IMAGE_DATA = os.path.join(CUR_DIR, '..', '..', '..', 'toydata', 'image-data') @@ -174,3 +182,27 @@ def test_validation(path_to_img): url = parse_obj_as(ImageUrl, path_to_img) assert isinstance(url, ImageUrl) assert isinstance(url, str) + + +@pytest.mark.parametrize( + 'file_type, file_source', + [ + (IMAGE_MIMETYPE, IMAGE_PATHS['png']), + (IMAGE_MIMETYPE, IMAGE_PATHS['jpg']), + (IMAGE_MIMETYPE, IMAGE_PATHS['jpeg']), + (IMAGE_MIMETYPE, REMOTE_JPG), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.mp3')), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.wav')), + (VIDEO_MIMETYPE, os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'test' 'test.html')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'test' 'test.md')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'penal_colony.txt')), + (OBJ_MIMETYPE, os.path.join(TOYDATA_DIR, 'test.glb')), + ], +) +def test_file_validation(file_type, file_source): + if file_type != ImageUrl.mime_type(): + with pytest.raises(ValueError): + parse_obj_as(ImageUrl, file_source) + else: + parse_obj_as(ImageUrl, file_source) diff --git a/tests/units/typing/url/test_mesh_url.py b/tests/units/typing/url/test_mesh_url.py index fb83a3362a2..71c354bb435 100644 --- a/tests/units/typing/url/test_mesh_url.py +++ b/tests/units/typing/url/test_mesh_url.py @@ -1,9 +1,18 @@ +import os + import numpy as np import pytest from pydantic.tools import parse_obj_as, schema_json_of from docarray.base_doc.io.json import orjson_dumps from docarray.typing import Mesh3DUrl, NdArray +from docarray.typing.url.mimetypes import ( + OBJ_MIMETYPE, + AUDIO_MIMETYPE, + VIDEO_MIMETYPE, + IMAGE_MIMETYPE, + TEXT_MIMETYPE, +) from tests import TOYDATA_DIR MESH_FILES = { @@ -75,3 +84,28 @@ def test_validation(path_to_file): def test_proto_mesh_url(): uri = parse_obj_as(Mesh3DUrl, REMOTE_OBJ_FILE) uri._to_node_protobuf() + + +@pytest.mark.parametrize( + 'file_type, file_source', + [ + (OBJ_MIMETYPE, MESH_FILES['obj']), + (OBJ_MIMETYPE, MESH_FILES['glb']), + (OBJ_MIMETYPE, MESH_FILES['ply']), + (OBJ_MIMETYPE, REMOTE_OBJ_FILE), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.aac')), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.mp3')), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.ogg')), + (VIDEO_MIMETYPE, os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')), + (IMAGE_MIMETYPE, os.path.join(TOYDATA_DIR, 'test.png')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'test' 'test.html')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'test' 'test.md')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'penal_colony.txt')), + ], +) +def test_file_validation(file_type, file_source): + if file_type != Mesh3DUrl.mime_type(): + with pytest.raises(ValueError): + parse_obj_as(Mesh3DUrl, file_source) + else: + parse_obj_as(Mesh3DUrl, file_source) diff --git a/tests/units/typing/url/test_point_cloud_url.py b/tests/units/typing/url/test_point_cloud_url.py index e48404fe9ce..88100928329 100644 --- a/tests/units/typing/url/test_point_cloud_url.py +++ b/tests/units/typing/url/test_point_cloud_url.py @@ -1,9 +1,18 @@ +import os + import numpy as np import pytest from pydantic.tools import parse_obj_as, schema_json_of from docarray.base_doc.io.json import orjson_dumps from docarray.typing import NdArray, PointCloud3DUrl +from docarray.typing.url.mimetypes import ( + OBJ_MIMETYPE, + AUDIO_MIMETYPE, + VIDEO_MIMETYPE, + IMAGE_MIMETYPE, + TEXT_MIMETYPE, +) from tests import TOYDATA_DIR MESH_FILES = { @@ -79,3 +88,28 @@ def test_validation(path_to_file): def test_proto_point_cloud_url(): uri = parse_obj_as(PointCloud3DUrl, REMOTE_OBJ_FILE) uri._to_node_protobuf() + + +@pytest.mark.parametrize( + 'file_type, file_source', + [ + (OBJ_MIMETYPE, MESH_FILES['obj']), + (OBJ_MIMETYPE, MESH_FILES['glb']), + (OBJ_MIMETYPE, MESH_FILES['ply']), + (OBJ_MIMETYPE, REMOTE_OBJ_FILE), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.aac')), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.mp3')), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.ogg')), + (VIDEO_MIMETYPE, os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')), + (IMAGE_MIMETYPE, os.path.join(TOYDATA_DIR, 'test.png')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'test' 'test.html')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'test' 'test.md')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'penal_colony.txt')), + ], +) +def test_file_validation(file_type, file_source): + if file_type != PointCloud3DUrl.mime_type(): + with pytest.raises(ValueError): + parse_obj_as(PointCloud3DUrl, file_source) + else: + parse_obj_as(PointCloud3DUrl, file_source) diff --git a/tests/units/typing/url/test_text_url.py b/tests/units/typing/url/test_text_url.py index ebee337ab65..a755344f394 100644 --- a/tests/units/typing/url/test_text_url.py +++ b/tests/units/typing/url/test_text_url.py @@ -6,6 +6,13 @@ from docarray.base_doc.io.json import orjson_dumps from docarray.typing import TextUrl +from docarray.typing.url.mimetypes import ( + OBJ_MIMETYPE, + AUDIO_MIMETYPE, + VIDEO_MIMETYPE, + IMAGE_MIMETYPE, + TEXT_MIMETYPE, +) from tests import TOYDATA_DIR REMOTE_TEXT_FILE = 'https://de.wikipedia.org/wiki/Brixen' @@ -89,3 +96,24 @@ def test_validation(path_to_file): url = parse_obj_as(TextUrl, path_to_file) assert isinstance(url, TextUrl) assert isinstance(url, str) + + +@pytest.mark.parametrize( + 'file_type, file_source', + [ + *[(TEXT_MIMETYPE, file) for file in LOCAL_TEXT_FILES], + (TEXT_MIMETYPE, REMOTE_TEXT_FILE), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.aac')), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.mp3')), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.ogg')), + (IMAGE_MIMETYPE, os.path.join(TOYDATA_DIR, 'test.png')), + (VIDEO_MIMETYPE, os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')), + (OBJ_MIMETYPE, os.path.join(TOYDATA_DIR, 'test.glb')), + ], +) +def test_file_validation(file_type, file_source): + if file_type != TextUrl.mime_type(): + with pytest.raises(ValueError): + parse_obj_as(TextUrl, file_source) + else: + parse_obj_as(TextUrl, file_source) diff --git a/tests/units/typing/url/test_video_url.py b/tests/units/typing/url/test_video_url.py index 726e66a0cb6..e3583bd5edd 100644 --- a/tests/units/typing/url/test_video_url.py +++ b/tests/units/typing/url/test_video_url.py @@ -1,3 +1,4 @@ +import os from typing import Optional import numpy as np @@ -15,6 +16,13 @@ VideoTorchTensor, VideoUrl, ) +from docarray.typing.url.mimetypes import ( + OBJ_MIMETYPE, + AUDIO_MIMETYPE, + VIDEO_MIMETYPE, + IMAGE_MIMETYPE, + TEXT_MIMETYPE, +) from docarray.utils._internal.misc import is_tf_available from tests import TOYDATA_DIR @@ -146,3 +154,26 @@ def test_load_bytes(): assert isinstance(video_bytes, bytes) assert isinstance(video_bytes, VideoBytes) assert len(video_bytes) > 0 + + +@pytest.mark.parametrize( + 'file_type, file_source', + [ + (VIDEO_MIMETYPE, LOCAL_VIDEO_FILE), + (VIDEO_MIMETYPE, REMOTE_VIDEO_FILE), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.aac')), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.mp3')), + (AUDIO_MIMETYPE, os.path.join(TOYDATA_DIR, 'hello.ogg')), + (IMAGE_MIMETYPE, os.path.join(TOYDATA_DIR, 'test.png')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'test' 'test.html')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'test' 'test.md')), + (TEXT_MIMETYPE, os.path.join(TOYDATA_DIR, 'penal_colony.txt')), + (OBJ_MIMETYPE, os.path.join(TOYDATA_DIR, 'test.glb')), + ], +) +def test_file_validation(file_type, file_source): + if file_type != VideoUrl.mime_type(): + with pytest.raises(ValueError): + parse_obj_as(VideoUrl, file_source) + else: + parse_obj_as(VideoUrl, file_source)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: