# SPDX-License-Identifier: Apache-2.0

# Standard
from difflib import SequenceMatcher
import logging

# Third Party
import pandas as pd
from datasets import Dataset

# Local
from ..registry import BlockRegistry
from ..utils.pandas import dataset_from_pandas_dataframe
from .block import Block

logger = logging.getLogger(__name__)


def _similarity(a: str, b: str) -> float:
    """Compute the SequenceMatcher similarity ratio between two strings.

    Empty or falsy inputs yield 0.0, so blank rows are never treated as
    duplicates of each other.
    """
    if not a or not b:
        return 0.0
    return SequenceMatcher(None, a, b).ratio()


def _deduplicate_group(group: pd.DataFrame, col: str, threshold: float) -> list:
    """Remove near-duplicate rows within a single group.

    Rows are scanned in order; a row is dropped when its text in *col* has a
    similarity ratio strictly greater than *threshold* with the text of any
    previously kept row.

    Returns the list of index labels to keep, in original row order.
    """
    kept_indices = []
    kept_texts = []

    # Iterate the single column we need directly instead of iterrows(),
    # which materializes a throwaway Series per row; Series.items() yields
    # the same (index label, value) pairs in the same order.
    for label, value in group[col].items():
        text = str(value)
        is_duplicate = any(
            _similarity(text, kept) > threshold for kept in kept_texts
        )
        if not is_duplicate:
            kept_indices.append(label)
            kept_texts.append(text)

    return kept_indices


# This is part of the public API.
@BlockRegistry.register("SimilarityFilterBlock")
class SimilarityFilterBlock(Block):
    """Drop rows whose text is near-identical to a previously kept row.

    Similarity is difflib's SequenceMatcher ratio; comparisons are made
    against every already-kept row, optionally restricted to groups.
    """

    def __init__(
        self,
        ctx,
        pipe,
        block_name,
        filter_column,
        threshold=0.85,
        group_by=None,
    ) -> None:
        """
        Initializes a new instance of the SimilarityFilterBlock class.

        Parameters:
        - ctx (PipelineContext): A PipelineContext object containing runtime parameters.
        - pipe (Pipeline): The Pipeline containing this block in its chain.
        - block_name (str): An identifier for this block.
        - filter_column (str): The column containing text to compare for similarity.
        - threshold (float): Similarity ratio (0.0 to 1.0). Rows with similarity
          above this value to any previously kept row are dropped. Default 0.85.
        - group_by (str, optional): Column to group by before deduplication.
          If set, similarity is only compared within each group. Default None.
        """
        super().__init__(ctx, pipe, block_name)
        self.filter_column = filter_column
        self.threshold = threshold
        self.group_by = group_by

    def generate(self, samples) -> Dataset:
        """Return *samples* with near-duplicate rows removed.

        An empty dataset is returned unchanged.
        """
        if len(samples) == 0:
            # Nothing to compare; skip the pandas round-trip entirely.
            return samples

        df = samples.to_pandas()
        original_len = len(df)

        if self.group_by and self.group_by not in df.columns:
            # Fix: a missing/misspelled group_by column used to silently fall
            # back to global deduplication; surface the misconfiguration.
            logger.warning(
                "SimilarityFilterBlock '%s': group_by column '%s' not found; "
                "deduplicating across the whole dataset",
                self.block_name,
                self.group_by,
            )

        if self.group_by and self.group_by in df.columns:
            kept_frames = []
            for _, group in df.groupby(self.group_by):
                kept = _deduplicate_group(group, self.filter_column, self.threshold)
                kept_frames.append(group.loc[kept])
            # groupby drops NaN keys by default, so kept_frames can be empty
            # even for a non-empty df; fall back to an empty frame.
            result = (
                pd.concat(kept_frames, ignore_index=True)
                if kept_frames
                else df.iloc[:0]
            )
        else:
            kept = _deduplicate_group(df, self.filter_column, self.threshold)
            # Fix: reset the index so this branch matches the grouped branch
            # (which concatenates with ignore_index=True) and no gappy pandas
            # index leaks into the resulting Dataset.
            result = df.loc[kept].reset_index(drop=True)

        removed = original_len - len(result)
        if removed > 0:
            logger.info(
                "SimilarityFilterBlock '%s': removed %d near-duplicates "
                "(threshold=%.2f), %d → %d rows",
                self.block_name,
                removed,
                self.threshold,
                original_len,
                len(result),
            )

        return dataset_from_pandas_dataframe(result)
# SPDX-License-Identifier: Apache-2.0

# Standard
from unittest.mock import MagicMock
import unittest

# Third Party
from datasets import Dataset

# First Party
from instructlab.sdg import SimilarityFilterBlock


class TestSimilarityFilterBlock(unittest.TestCase):
    """Behavioral tests for SimilarityFilterBlock deduplication."""

    def setUp(self):
        self.ctx = MagicMock()
        self.ctx.dataset_num_procs = 1
        self.pipe = MagicMock()

    def _make_block(self, filter_column="text", threshold=0.85, group_by=None):
        """Construct a block under test wired to mocked pipeline objects."""
        return SimilarityFilterBlock(
            self.ctx,
            self.pipe,
            "test_similarity_filter",
            filter_column=filter_column,
            threshold=threshold,
            group_by=group_by,
        )

    @staticmethod
    def _dataset(texts, doc_ids=None):
        """Build a small Dataset with a text column and optional doc_id column."""
        columns = {"text": texts}
        if doc_ids is not None:
            columns["doc_id"] = doc_ids
        return Dataset.from_dict(columns)

    def test_keeps_unique_rows(self):
        texts = ["alpha bravo charlie", "delta echo foxtrot", "golf hotel india"]
        result = self._make_block().generate(self._dataset(texts))
        self.assertEqual(len(result), 3)

    def test_removes_exact_duplicates(self):
        result = self._make_block(threshold=0.8).generate(
            self._dataset(["hello world"] * 3)
        )
        self.assertEqual(len(result), 1)

    def test_removes_near_duplicates(self):
        texts = [
            "What is photosynthesis and how does it work?",
            "What is photosynthesis and how does it function?",
            "Explain the process of sourdough bread making.",
        ]
        result = self._make_block(threshold=0.7).generate(self._dataset(texts))
        self.assertEqual(len(result), 2)

    def test_group_by_isolates_groups(self):
        block = self._make_block(threshold=0.8, group_by="doc_id")
        result = block.generate(
            self._dataset(["same text here", "same text here"], ["doc_a", "doc_b"])
        )
        self.assertEqual(len(result), 2)

    def test_group_by_deduplicates_within_group(self):
        block = self._make_block(threshold=0.8, group_by="doc_id")
        result = block.generate(
            self._dataset(["same text here", "same text here"], ["doc_a", "doc_a"])
        )
        self.assertEqual(len(result), 1)

    def test_empty_dataset(self):
        result = self._make_block().generate(self._dataset([]))
        self.assertEqual(len(result), 0)

    def test_low_threshold_more_aggressive(self):
        texts = [
            "What is photosynthesis?",
            "What is the process of photosynthesis?",
            "Explain sourdough bread.",
        ]
        ds = self._dataset(texts)
        aggressive = self._make_block(threshold=0.5)
        permissive = self._make_block(threshold=0.95)
        self.assertLessEqual(
            len(aggressive.generate(ds)), len(permissive.generate(ds))
        )