diff --git a/clarifai/client/model.py b/clarifai/client/model.py index edc3428d..8f872bdc 100644 --- a/clarifai/client/model.py +++ b/clarifai/client/model.py @@ -25,6 +25,7 @@ MODEL_EXPORT_TIMEOUT, RANGE_SIZE, TRAINABLE_MODEL_TYPES) from clarifai.errors import UserError from clarifai.urls.helper import ClarifaiUrlHelper +from clarifai.utils import video_utils from clarifai.utils.logging import logger from clarifai.utils.misc import BackoffIterator, status_is_retryable from clarifai.utils.model_train import (find_and_replace_key, params_parser, @@ -424,14 +425,14 @@ def predict(self, raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}." ) # TODO Use Chunker for inputs len > 128 - self._override_model_version(inference_params, output_config) + model_info = self._get_model_info_for_inference(inference_params, output_config) request = service_pb2.PostModelOutputsRequest( user_app_id=self.user_app_id, model_id=self.id, version_id=self.model_version.id, inputs=inputs, runner_selector=runner_selector, - model=self.model_info) + model=model_info) start_time = time.time() backoff_iterator = BackoffIterator(10) @@ -704,14 +705,14 @@ def generate(self, raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}." ) # TODO Use Chunker for inputs len > 128 - self._override_model_version(inference_params, output_config) + model_info = self._get_model_info_for_inference(inference_params, output_config) request = service_pb2.PostModelOutputsRequest( user_app_id=self.user_app_id, model_id=self.id, version_id=self.model_version.id, inputs=inputs, runner_selector=runner_selector, - model=self.model_info) + model=model_info) start_time = time.time() backoff_iterator = BackoffIterator(10) @@ -925,7 +926,8 @@ def generate_by_url(self, inference_params=inference_params, output_config=output_config) - def _req_iterator(self, input_iterator: Iterator[List[Input]], runner_selector: RunnerSelector): + def _req_iterator(self, input_iterator: Iterator[List[Input]], runner_selector: RunnerSelector, + model_info: resources_pb2.Model): for inputs in input_iterator: yield service_pb2.PostModelOutputsRequest( user_app_id=self.user_app_id, @@ -933,7 +935,7 @@ def _req_iterator(self, input_iterator: Iterator[List[Input]], runner_selector: version_id=self.model_version.id, inputs=inputs, runner_selector=runner_selector, - model=self.model_info) + model=model_info) def stream(self, inputs: Iterator[List[Input]], @@ -957,8 +959,8 @@ def stream(self, # if not isinstance(inputs, Iterator[List[Input]]): # raise UserError('Invalid inputs, inputs must be a iterator of list of Input objects.') - self._override_model_version(inference_params, output_config) - request = self._req_iterator(inputs, runner_selector) + model_info = self._get_model_info_for_inference(inference_params, output_config) + request = self._req_iterator(inputs, runner_selector, model_info) start_time = time.time() backoff_iterator = BackoffIterator(10) @@ -1171,8 +1173,53 @@ def input_generator(): inference_params=inference_params, output_config=output_config) - def _override_model_version(self, inference_params: Dict = {}, output_config: Dict = {}) -> None: - """Overrides the model version. + def stream_by_video_file(self, + filepath: str, + input_type: str = 'video', + compute_cluster_id: str = None, + nodepool_id: str = None, + deployment_id: str = None, + user_id: str = None, + inference_params: Dict = {}, + output_config: Dict = {}): + """ + Stream the model output based on the given video file. 
+
+    Converts the video file to a streamable format, streams it as bytes to the model,
+    and streams back the model outputs.
+
+    Args:
+        filepath (str): The path to the video file to stream.
+        input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
+        compute_cluster_id (str): The compute cluster ID to use for the model.
+        nodepool_id (str): The nodepool ID to use for the model.
+        deployment_id (str): The deployment ID to use for the model.
+        user_id (str): The user ID to use for the model.
+        inference_params (dict): The inference params to override.
+        output_config (dict): The output config to override.
+    """
+
+    if not os.path.isfile(filepath):
+      raise UserError('Invalid filepath.')
+
+    # TODO check if the file is streamable already
+
+    # Convert the video file to a streamable format
+    # TODO this conversion can offset the start time by a little bit; we should account for this
+    # by getting the original start time via ffprobe and either sending that to the model so it
+    # can adjust with the ts of the first frame (too fragile to do all of this adjustment in the
+    # client input stream) or by adjusting the timestamps in the output stream
+    stream = video_utils.convert_to_streamable(filepath)
+
+    # TODO accumulate reads to fill the chunk size
+    chunk_size = 1024 * 1024  # 1 MB
+    chunk_iterator = iter(lambda: stream.read(chunk_size), b'')
+
+    return self.stream_by_bytes(chunk_iterator, input_type, compute_cluster_id, nodepool_id,
+                                deployment_id, user_id, inference_params, output_config)
+
+  def _get_model_info_for_inference(self, inference_params: Dict = {},
+                                    output_config: Dict = {}) -> resources_pb2.Model:
+    """Gets a copy of the model info, with the given inference params and output config applied.

     Args:
       inference_params (dict): The inference params to override.
@@ -1182,13 +1229,12 @@ def _override_model_version(self, inference_params: Dict = {}, output_config: Di
         select_concepts (list[Concept]): The concepts to select.
         sample_ms (int): The number of milliseconds to sample.
     """
-    params = Struct()
-    if inference_params is not None:
-      params.update(inference_params)
-
-    self.model_info.model_version.output_info.CopyFrom(
-        resources_pb2.OutputInfo(
-            output_config=resources_pb2.OutputConfig(**output_config), params=params))
+    model_info = resources_pb2.Model()
+    model_info.CopyFrom(self.model_info)
+    if inference_params is not None:
+      model_info.model_version.output_info.params.update(inference_params)
+    model_info.model_version.output_info.output_config.CopyFrom(
+        resources_pb2.OutputConfig(**output_config))
+    return model_info

   def _list_concepts(self) -> List[str]:
     """Lists all the concepts for the model type.
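A minimal usage sketch of the new `stream_by_video_file` entry point (the model URL, PAT, and deployment ID are hypothetical placeholders, and the exact shape of each streamed response depends on the model type):

    from clarifai.client.model import Model

    # Hypothetical model URL and credentials.
    model = Model(url="https://clarifai.com/<user>/<app>/models/<video-model>", pat="YOUR_PAT")

    # The file is remuxed to MPEG-TS, streamed to the model as 1 MB byte chunks,
    # and outputs are streamed back as they are produced.
    for response in model.stream_by_video_file(
        "/path/to/video.mp4", input_type='video', deployment_id="my-deployment"):
      print(response)
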
diff --git a/clarifai/runners/models/base_typed_model.py b/clarifai/runners/models/base_typed_model.py index 2809d0c0..aff81ef3 100644 --- a/clarifai/runners/models/base_typed_model.py +++ b/clarifai/runners/models/base_typed_model.py @@ -6,6 +6,9 @@ from clarifai_grpc.grpc.api.service_pb2 import PostModelOutputsRequest from google.protobuf import json_format +from clarifai.runners.utils.url_fetcher import ensure_urls_downloaded +from clarifai.utils.stream_utils import readahead + from ..utils.data_handler import InputDataHandler, OutputDataHandler from .model_class import ModelClass @@ -46,12 +49,16 @@ def convert_output_to_proto(self, outputs: list): def predict_wrapper( self, request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse: + if self.download_request_urls: + ensure_urls_downloaded(request) list_dict_input, inference_params = self.parse_input_request(request) outputs = self.predict(list_dict_input, inference_parameters=inference_params) return self.convert_output_to_proto(outputs) def generate_wrapper( self, request: PostModelOutputsRequest) -> Iterator[service_pb2.MultiOutputResponse]: + if self.download_request_urls: + ensure_urls_downloaded(request) list_dict_input, inference_params = self.parse_input_request(request) outputs = self.generate(list_dict_input, inference_parameters=inference_params) for output in outputs: @@ -64,11 +71,13 @@ def _preprocess_stream( input_data, _ = self.parse_input_request(req) yield input_data - def stream_wrapper(self, request: Iterator[PostModelOutputsRequest] + def stream_wrapper(self, request_iterator: Iterator[PostModelOutputsRequest] ) -> Iterator[service_pb2.MultiOutputResponse]: - first_request = next(request) + if self.download_request_urls: + request_iterator = readahead(map(ensure_urls_downloaded, request_iterator)) + first_request = next(request_iterator) _, inference_params = self.parse_input_request(first_request) - request_iterator = itertools.chain([first_request], request) + request_iterator = itertools.chain([first_request], request_iterator) outputs = self.stream(self._preprocess_stream(request_iterator), inference_params) for output in outputs: yield self.convert_output_to_proto(output) diff --git a/clarifai/runners/models/model_class.py b/clarifai/runners/models/model_class.py index 5b342ba2..bf2e234a 100644 --- a/clarifai/runners/models/model_class.py +++ b/clarifai/runners/models/model_class.py @@ -3,23 +3,41 @@ from clarifai_grpc.grpc.api import service_pb2 +from clarifai.runners.utils.url_fetcher import ensure_urls_downloaded +from clarifai.utils.stream_utils import readahead + class ModelClass(ABC): + download_request_urls = True + def predict_wrapper( self, request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse: """This method is used for input/output proto data conversion""" + # Download any urls that are not already bytes. + if self.download_request_urls: + ensure_urls_downloaded(request) + return self.predict(request) def generate_wrapper(self, request: service_pb2.PostModelOutputsRequest ) -> Iterator[service_pb2.MultiOutputResponse]: """This method is used for input/output proto data conversion and yield outcome""" + # Download any urls that are not already bytes. 
+ if self.download_request_urls: + ensure_urls_downloaded(request) + return self.generate(request) - def stream_wrapper(self, request: service_pb2.PostModelOutputsRequest + def stream_wrapper(self, request_stream: Iterator[service_pb2.PostModelOutputsRequest] ) -> Iterator[service_pb2.MultiOutputResponse]: """This method is used for input/output proto data conversion and yield outcome""" - return self.stream(request) + + # Download any urls that are not already bytes. + if self.download_request_urls: + request_stream = readahead(map(ensure_urls_downloaded, request_stream)) + + return self.stream(request_stream) @abstractmethod def load_model(self): diff --git a/clarifai/runners/models/model_runner.py b/clarifai/runners/models/model_runner.py index a24cba4d..9e8a495a 100644 --- a/clarifai/runners/models/model_runner.py +++ b/clarifai/runners/models/model_runner.py @@ -5,7 +5,6 @@ from clarifai_protocol import BaseRunner from clarifai_protocol.utils.health import HealthProbeRequestHandler -from ..utils.url_fetcher import ensure_urls_downloaded from .model_class import ModelClass @@ -79,7 +78,6 @@ def runner_item_predict(self, if not runner_item.HasField('post_model_outputs_request'): raise Exception("Unexpected work item type: {}".format(runner_item)) request = runner_item.post_model_outputs_request - ensure_urls_downloaded(request) resp = self.model.predict_wrapper(request) successes = [o.status.code == status_code_pb2.SUCCESS for o in resp.outputs] @@ -109,7 +107,6 @@ def runner_item_generate( if not runner_item.HasField('post_model_outputs_request'): raise Exception("Unexpected work item type: {}".format(runner_item)) request = runner_item.post_model_outputs_request - ensure_urls_downloaded(request) for resp in self.model.generate_wrapper(request): successes = [] @@ -169,5 +166,4 @@ def pmo_iterator(runner_item_iterator): for runner_item in runner_item_iterator: if not runner_item.HasField('post_model_outputs_request'): raise Exception("Unexpected work item type: {}".format(runner_item)) - ensure_urls_downloaded(runner_item.post_model_outputs_request) yield runner_item.post_model_outputs_request diff --git a/clarifai/runners/models/model_servicer.py b/clarifai/runners/models/model_servicer.py index f241271c..d6252517 100644 --- a/clarifai/runners/models/model_servicer.py +++ b/clarifai/runners/models/model_servicer.py @@ -1,11 +1,8 @@ -from itertools import tee from typing import Iterator from clarifai_grpc.grpc.api import service_pb2, service_pb2_grpc from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2 -from ..utils.url_fetcher import ensure_urls_downloaded - class ModelServicer(service_pb2_grpc.V2Servicer): """ @@ -27,9 +24,6 @@ def PostModelOutputs(self, request: service_pb2.PostModelOutputsRequest, returns an output. """ - # Download any urls that are not already bytes. - ensure_urls_downloaded(request) - try: return self.model.predict_wrapper(request) except Exception as e: @@ -46,9 +40,6 @@ def GenerateModelOutputs(self, request: service_pb2.PostModelOutputsRequest, This is the method that will be called when the servicer is run. It takes in an input and returns an output. """ - # Download any urls that are not already bytes. - ensure_urls_downloaded(request) - try: return self.model.generate_wrapper(request) except Exception as e: @@ -66,15 +57,8 @@ def StreamModelOutputs(self, This is the method that will be called when the servicer is run. It takes in an input and returns an output. 
""" - # Duplicate the iterator - request, request_copy = tee(request) - - # Download any urls that are not already bytes. - for req in request: - ensure_urls_downloaded(req) - try: - return self.model.stream_wrapper(request_copy) + return self.model_class.stream_wrapper(request) except Exception as e: yield service_pb2.MultiOutputResponse(status=status_pb2.Status( code=status_code_pb2.MODEL_PREDICTION_FAILED, diff --git a/clarifai/runners/utils/url_fetcher.py b/clarifai/runners/utils/url_fetcher.py index 081d298b..8531d283 100644 --- a/clarifai/runners/utils/url_fetcher.py +++ b/clarifai/runners/utils/url_fetcher.py @@ -1,8 +1,10 @@ import concurrent.futures +from typing import Iterable import fsspec from clarifai.utils.logging import logger +from clarifai.utils.stream_utils import MB def download_input(input): @@ -47,3 +49,13 @@ def ensure_urls_downloaded(request, max_threads=128): future.result() except Exception as e: logger.exception(f"Error downloading input: {e}") + return request + + +def stream_url(url: str, chunk_size: int = 1 * MB) -> Iterable[bytes]: + """ + Opens a stream of byte chunks from a URL. + """ + # block_size=0 means that the file is streamed + with fsspec.open(url, 'rb', block_size=0) as f: + yield from iter(lambda: f.read(chunk_size), b'') diff --git a/clarifai/utils/misc.py b/clarifai/utils/misc.py index 55d2ff0b..16848708 100644 --- a/clarifai/utils/misc.py +++ b/clarifai/utils/misc.py @@ -1,3 +1,4 @@ +import importlib import os import re import uuid @@ -18,6 +19,29 @@ def status_is_retryable(status_code: int) -> bool: return status_code in RETRYABLE_CODES +def optional_import(module_name: str, pip_package: str = None): + """Import a module if it exists. + Otherwise, return an object that will raise an error when accessed. + """ + try: + return importlib.import_module(module_name) + except ImportError: + return _MissingModule(module_name, pip_package=pip_package) + + +class _MissingModule: + """Object that raises an error when accessed.""" + + def __init__(self, module_name, pip_package=None): + self.module_name = module_name + self.message = f"Module `{module_name}` is not installed." + if pip_package: + self.message += f" Please add `{pip_package}` to your requirements.txt file." + + def __getattr__(self, name): + raise ImportError(self.message) + + class Chunker: """Split an input sequence into small chunks.""" diff --git a/clarifai/utils/stream_utils.py b/clarifai/utils/stream_utils.py new file mode 100644 index 00000000..6b51b582 --- /dev/null +++ b/clarifai/utils/stream_utils.py @@ -0,0 +1,239 @@ +import io +import queue + +import threading +from concurrent.futures import ThreadPoolExecutor + +MB = 1024 * 1024 + + +class StreamingChunksReader(io.RawIOBase): + ''' + A buffered reader that reads data from an iterator yielding chunks of bytes, used + to provide file-like access to a streaming data source. + + :param chunk_iterator: An iterator that yields chunks of data (bytes) + ''' + + def __init__(self, chunk_iterator): + """ + Args: + chunk_iterator (iterator): An iterator that yields chunks of bytes. 
+ """ + self._chunk_iterator = chunk_iterator + self.response = None + self.buffer = b'' + self.b_pos = 0 + self._eof = False + + def readable(self): + return True + + def readinto(self, output_buf): + if self._eof: + return 0 + + try: + # load next chunk if necessary + if self.b_pos == len(self.buffer): + self.buffer = next(self._chunk_iterator) + self.b_pos = 0 + + # copy data to output buffer + n = min(len(output_buf), len(self.buffer) - self.b_pos) + assert n > 0 + + output_buf[:n] = self.buffer[self.b_pos:self.b_pos + n] + + # advance positions + self.b_pos += n + assert self.b_pos <= len(self.buffer) + + return n + + except StopIteration: + self._eof = True + return 0 + + +class SeekableStreamingChunksReader(io.RawIOBase): + """ + A buffered reader that reads data from an iterator yielding chunks of bytes, used + to provide file-like access to a streaming data source. + + This class allows supports limited seeking to positions within the stream, by buffering + buffering chunks internally and supporting basic seek operations within the buffer. + """ + + def __init__(self, chunk_iterator, buffer_size=100 * MB): + """ + Args: + chunk_iterator (iterator): An iterator that yields chunks of bytes. + buffer_size (int): Maximum buffer size in bytes before old chunks are discarded. + """ + self._chunk_iterator = chunk_iterator + self.buffer_size = buffer_size + self.buffer_vec = [] + self.file_pos = 0 + self.vec_pos = 0 + self.b_pos = 0 + self._eof = False + + #### read() methods + + def readable(self): + return True + + def readinto(self, output_buf): + """ + Read data into the given buffer. + + Args: + output_buf (bytearray): Buffer to read data into. + + Returns: + int: Number of bytes read. + """ + if self._eof: + return 0 + + assert self.vec_pos <= len(self.buffer_vec) + + try: + # load next chunk if necessary + if self.vec_pos == len(self.buffer_vec): + self._load_next_chunk() + + # copy data from buffer_vec to output buffer + n = min(len(output_buf), len(self.buffer_vec[self.vec_pos]) - self.b_pos) + assert n > 0 + + output_buf[:n] = self.buffer_vec[self.vec_pos][self.b_pos:self.b_pos + n] + + # advance positions + self.file_pos += n + self.b_pos += n + assert self.b_pos <= len(self.buffer_vec[self.vec_pos]) + if self.b_pos == len(self.buffer_vec[self.vec_pos]): + self.vec_pos += 1 + self.b_pos = 0 + return n + except StopIteration: + self._eof = True + return 0 + + def _load_next_chunk(self, check_bounds=True): + self.buffer_vec.append(next(self._chunk_iterator)) + total = sum(len(chunk) for chunk in self.buffer_vec) + while total > self.buffer_size and len(self.buffer_vec) > 1: # keep at least the last chunk + chunk = self.buffer_vec.pop(0) + total -= len(chunk) + self.vec_pos -= 1 + if check_bounds: + assert self.vec_pos >= 0, 'current position fell outside the buffer' + + #### seek() methods (experimental) + + def seekable(self): + return True + + def tell(self): + return self.file_pos + + def seek(self, offset, whence=io.SEEK_SET): + """ + Seek to a new position in the buffered stream. + + Args: + offset (int): The offset to seek to. + whence (int): The reference position (SEEK_SET, SEEK_CUR). + SEEK_END is not supported. + + Returns: + int: The new file position. + + Raises: + ValueError: If an invalid `whence` value is provided. + IOError: If seeking before the start of the buffer. 
+ """ + if whence == io.SEEK_SET: + seek_pos = offset + elif whence == io.SEEK_CUR: + seek_pos = self.file_pos + offset + elif whence == io.SEEK_END: + raise ValueError('SEEK_END is not supported') + else: + raise ValueError(f"Invalid whence: {whence}") + + # set positions to start of buffer vec to begin seeking + self.file_pos -= self.b_pos + self.b_pos = 0 + while self.vec_pos > 0: + self.vec_pos -= 1 + self.file_pos -= len(self.buffer_vec[self.vec_pos]) + + # check if still seeking backwards off the start of the buffer + if seek_pos < self.file_pos: + raise IOError('seek before start of buffer') + + # seek forwards to desired position + while self.file_pos < seek_pos: + if self.vec_pos == len(self.buffer_vec): + self._load_next_chunk() + n = len(self.buffer_vec[self.vec_pos]) + if self.file_pos + n > seek_pos: + self.b_pos = seek_pos - self.file_pos + self.file_pos = seek_pos + break + self.file_pos += n + self.vec_pos += 1 + + # unset EOF flag + self._eof = False + + return self.file_pos + + +def readahead(iterator, n=1, daemon=True): + """ + Iterator wrapper that reads ahead from the underlying iterator, using a background thread. + + :Args: + iterator (iterator): The iterator to read from. + n (int): The maximum number of items to read ahead. + daemon (bool): Whether the background thread should be a daemon thread. + """ + q = queue.Queue(maxsize=n) + _sentinel = object() + + def _read(): + for x in iterator: + q.put(x) + q.put(_sentinel) + + t = threading.Thread(target=_read, daemon=daemon) + t.start() + while True: + x = q.get() + if x is _sentinel: + break + yield x + + +def map(f, iterator, parallel=1): + ''' + Apply a function to each item in an iterator, optionally using multiple threads. + Similar to the built-in `map` function, but with support for parallel execution. + ''' + if parallel < 1: + return map(f, iterator) + with ThreadPoolExecutor(max_workers=parallel) as executor: + futures = [] + for i in range(parallel): + futures.append(executor.submit(f, next(iterator))) + for r in iterator: + res = futures.pop(0).result() + futures.append(executor.submit(f, r)) # start computing next result before yielding this one + yield res + for f in futures: + yield f.result() diff --git a/clarifai/utils/video_utils.py b/clarifai/utils/video_utils.py new file mode 100644 index 00000000..3e39ea71 --- /dev/null +++ b/clarifai/utils/video_utils.py @@ -0,0 +1,110 @@ +import io +import os +import tempfile +import threading + +import requests + +from clarifai.utils import stream_utils +from clarifai.utils.misc import optional_import + +av = optional_import("av", pip_package="av") + + +def stream_frames_from_url(url, download_ok=True): + """ + Streams a video at the specified resolution using PyAV. + + :param url: The video URL + :param download_ok: Whether to download the video if the URL is not a stream + """ + protocol = url.split('://', 1)[0] + if protocol == 'rtsp': + # stream from RTSP and send to PyAV + container = av.open(url) + elif protocol in ('http', 'https'): + if not download_ok: + raise ValueError('Download not allowed for URL scheme') + # download the video to the temporary file + # TODO: download just enough to get the file header and stream to pyav if possible, + # otherwise download the whole file + # e.g. if linking to a streamable file format like mpegts (not mp4) + file = tempfile.NamedTemporaryFile(delete=True) + download_file(url, file.name) + container = av.open(file.name) + else: + # TODO others: s3, etc. 
+ raise ValueError('Unsupported URL scheme') + + # Decode video frames + yield from container.decode(video=0) + + +def download_file(url, file_name): + response = requests.get(url, stream=True) + response.raise_for_status() + with open(file_name, 'wb') as f: + for chunk in response.iter_content(chunk_size=1024): + f.write(chunk) + + +def stream_frames_from_file(filename): + """ + Streams a video from a file using PyAV. + + :param filename: The video file path + """ + container = av.open(filename) + yield from container.decode(video=0) + + +def stream_frames_from_bytes(bytes_iterator): + """ + Streams a video from a sequence of chunked byte strings of a streamable video + container format. + + :param bytes_iterator: An iterator that yields byte chunks with the video data + """ + buffer = stream_utils.StreamingChunksReader(bytes_iterator) + reader = io.BufferedReader(buffer) + container = av.open(reader) + yield from container.decode(video=0) + + +def convert_to_streamable(filepath): + return recontain(filepath, "mpegts", {"muxpreload": "0", "muxdelay": "0"}) + + +def recontain(input, format, options={}): + # pyav-only implementation of "ffmpeg -i filepath -f mpegts -muxpreload 0 -muxdelay 0 pipe:" + read_pipe_fd, write_pipe_fd = os.pipe() + read_pipe = os.fdopen(read_pipe_fd, "rb") + write_pipe = os.fdopen(write_pipe_fd, "wb") + + def _run_av(): + input_container = output_container = None + try: + # open input and output containers, using mpegts as output format + input_container = av.open(input, options=options) + output_container = av.open(write_pipe, mode="w", format=format) + + # Copy streams directly without re-encoding + for stream in input_container.streams: + output_container.add_stream_from_template(stream) + + # Read packets from input and write them to output + for packet in input_container.demux(): + if not packet.size: + break + output_container.mux(packet) + + finally: + if output_container: + output_container.close() + if input_container: + input_container.close() + + t = threading.Thread(target=_run_av) + t.start() + + return read_pipe diff --git a/requirements.txt b/requirements.txt index cd4c3759..08407ade 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ tabulate>=0.9.0 fsspec>=2024.6.1 click>=8.1.7 requests>=2.32.3 +aiohttp>=3.10 diff --git a/tests/runners/test_stream_utils.py b/tests/runners/test_stream_utils.py new file mode 100644 index 00000000..fc136cfb --- /dev/null +++ b/tests/runners/test_stream_utils.py @@ -0,0 +1,77 @@ +import io +import unittest + +from clarifai.utils.stream_utils import SeekableStreamingChunksReader, StreamingChunksReader + + +class TestStreamingChunksReader(unittest.TestCase): + + def setUp(self): + self.chunks = [b'hello', b'world', b'12345'] + #self.reader = BufferStream(iter(self.chunks), buffer_size=10) + self.reader = StreamingChunksReader(iter(self.chunks)) + + def test_read(self): + buffer = bytearray(5) + self.assertEqual(self.reader.readinto(buffer), 5) + self.assertEqual(buffer, b'hello') + + def test_read_file(self): + self.assertEqual(self.reader.read(5), b'hello') + + def test_read_partial_chunk(self): + """Test reading fewer bytes than a chunk contains, across multiple reads.""" + buffer = bytearray(3) + self.assertEqual(self.reader.readinto(buffer), 3) + self.assertEqual(buffer, b'hel') + self.assertEqual(self.reader.readinto(buffer), 2) + self.assertEqual(buffer[:2], b'lo') + self.assertEqual(self.reader.readinto(buffer), 3) + self.assertEqual(buffer, b'wor') + + def test_large_chunk(self): + """Test 
handling a chunk larger than the buffer size.""" + large_chunk = b'a' * 20 + reader = StreamingChunksReader(iter([large_chunk])) + buffer = bytearray(10) + self.assertEqual(reader.readinto(buffer), 10) + self.assertEqual(buffer, b'a' * 10) + self.assertEqual(reader.readinto(buffer), 10) + self.assertEqual(buffer, b'a' * 10) + + +class TestSeekableStreamingChunksReader(TestStreamingChunksReader): + + def setUp(self): + self.chunks = [b'hello', b'world', b'12345'] + self.reader = SeekableStreamingChunksReader(iter(self.chunks), buffer_size=10) + + def test_interleaved_read_and_seek(self): + """Test alternating read and seek operations.""" + buffer = bytearray(5) + self.reader.readinto(buffer) + self.assertEqual(buffer, b'hello') + buffer[:] = b'xxxxx' + self.reader.seek(0) + self.assertEqual(self.reader.readinto(buffer), 5) + self.assertEqual(buffer, b'hello') + self.reader.seek(7) + n = self.reader.readinto(buffer) + assert 1 <= n <= len(buffer) + self.assertEqual(buffer[:n], b''.join(self.chunks)[7:7 + n]) + + def test_seek_and_tell(self): + """Test seeking to a position and confirming it with tell().""" + self.reader.seek(5) + self.assertEqual(self.reader.tell(), 5) + self.reader.seek(-2, io.SEEK_CUR) + self.assertEqual(self.reader.tell(), 3) + + def test_seek_out_of_bounds(self): + """Test seeking to a negative position, which should raise an IOError.""" + with self.assertRaises(IOError): + self.reader.seek(-1) + + +if __name__ == '__main__': + unittest.main()
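A small self-contained sketch of how the new stream utilities compose, substituting an in-memory chunk source for a network or pipe stream:

    import io

    from clarifai.utils.stream_utils import (MB, SeekableStreamingChunksReader,
                                             StreamingChunksReader, readahead)

    # In-memory chunk source standing in for a network or pipe stream.
    chunks = [b'hello', b'world', b'12345']

    # Prefetch chunks on a background thread and expose them as a file object.
    reader = io.BufferedReader(StreamingChunksReader(readahead(iter(chunks), n=2)))
    print(reader.read())  # b'helloworld12345'

    # The seekable variant additionally supports rewinding within its internal buffer.
    seekable = SeekableStreamingChunksReader(iter(chunks), buffer_size=10 * MB)
    seekable.seek(5)
    print(seekable.read(5))  # b'world'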