Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/instructlab/sdg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"SamplePopulatorBlock",
"SelectorBlock",
"SetToMajorityValueBlock",
"SimilarityFilterBlock",
"FULL_PIPELINES_PACKAGE",
"SIMPLE_PIPELINES_PACKAGE",
"LLAMA_PIPELINES_PKG",
Expand All @@ -37,6 +38,7 @@
from .blocks.block import Block, BlockConfigParserError
from .blocks.filterblock import FilterByValueBlock, FilterByValueBlockError
from .blocks.iterblock import IterBlock
from .blocks.similarityfilterblock import SimilarityFilterBlock
from .blocks.llmblock import (
ConditionalLLMBlock,
LLMBlock,
Expand Down
109 changes: 109 additions & 0 deletions src/instructlab/sdg/blocks/similarityfilterblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from difflib import SequenceMatcher
import logging

# Third Party
import pandas as pd
from datasets import Dataset

# Local
from ..registry import BlockRegistry
from ..utils.pandas import dataset_from_pandas_dataframe
from .block import Block

logger = logging.getLogger(__name__)


def _similarity(a: str, b: str) -> float:
"""Compute similarity ratio between two strings."""
if not a or not b:
return 0.0
return SequenceMatcher(None, a, b).ratio()


def _deduplicate_group(group, col, threshold):
    """Return the row labels of *group* that survive near-duplicate removal.

    Rows are scanned in order. A row is dropped when the text in *col* is
    more similar than *threshold* to the text of any row already kept;
    otherwise its label is recorded as a survivor.
    """
    # (row label, text) pairs for rows kept so far
    survivors = []

    for row_label, row in group.iterrows():
        candidate = str(row[col])
        # Keep the row only if it clears the threshold against every survivor.
        if all(
            _similarity(candidate, seen) <= threshold for _, seen in survivors
        ):
            survivors.append((row_label, candidate))

    return [label for label, _ in survivors]


# This is part of the public API.
@BlockRegistry.register("SimilarityFilterBlock")
class SimilarityFilterBlock(Block):
    """Drop rows whose text is a near-duplicate of a previously seen row.

    Similarity is measured with difflib's SequenceMatcher ratio on the
    ``filter_column`` text. The first occurrence of each cluster of
    near-duplicates is kept; later rows above the threshold are dropped.
    """

    def __init__(
        self,
        ctx,
        pipe,
        block_name,
        filter_column,
        threshold=0.85,
        group_by=None,
    ) -> None:
        """
        Initializes a new instance of the SimilarityFilterBlock class.

        Parameters:
        - ctx (PipelineContext): A PipelineContext object containing runtime parameters.
        - pipe (Pipeline): The Pipeline containing this block in its chain.
        - block_name (str): An identifier for this block.
        - filter_column (str): The column containing text to compare for similarity.
        - threshold (float): Similarity ratio (0.0 to 1.0). Rows with similarity
            above this value to any previously kept row are dropped. Default 0.85.
        - group_by (str, optional): Column to group by before deduplication.
            If set, similarity is only compared within each group. Default None.
        """
        super().__init__(ctx, pipe, block_name)
        self.filter_column = filter_column
        self.threshold = threshold
        self.group_by = group_by

    def generate(self, samples) -> Dataset:
        """Return *samples* with near-duplicate rows removed.

        Empty datasets are returned unchanged. When ``group_by`` names a
        column that is missing from the data, a warning is logged and the
        whole dataset is deduplicated as a single group.
        """
        if len(samples) == 0:
            return samples

        df = samples.to_pandas()
        original_len = len(df)

        if self.group_by and self.group_by not in df.columns:
            # Previously this fell through silently; make the fallback to
            # global deduplication visible to the operator.
            logger.warning(
                "SimilarityFilterBlock '%s': group_by column '%s' not found; "
                "deduplicating across the entire dataset",
                self.block_name,
                self.group_by,
            )

        if self.group_by and self.group_by in df.columns:
            groups = []
            for _, group in df.groupby(self.group_by):
                kept = _deduplicate_group(group, self.filter_column, self.threshold)
                groups.append(group.loc[kept])
            result = (
                pd.concat(groups, ignore_index=True)
                if groups
                else df.iloc[:0]
            )
        else:
            kept = _deduplicate_group(df, self.filter_column, self.threshold)
            # reset_index keeps a default RangeIndex, matching the grouped
            # path (ignore_index=True) and preventing the pandas -> Dataset
            # conversion from emitting a spurious index column.
            result = df.loc[kept].reset_index(drop=True)

        removed = original_len - len(result)
        if removed > 0:
            logger.info(
                "SimilarityFilterBlock '%s': removed %d near-duplicates "
                "(threshold=%.2f), %d → %d rows",
                self.block_name,
                removed,
                self.threshold,
                original_len,
                len(result),
            )

        return dataset_from_pandas_dataframe(result)
97 changes: 97 additions & 0 deletions tests/unit/test_similarityfilterblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from unittest.mock import MagicMock
import unittest

# Third Party
from datasets import Dataset

# First Party
from instructlab.sdg import SimilarityFilterBlock


class TestSimilarityFilterBlock(unittest.TestCase):
    """Unit tests for SimilarityFilterBlock near-duplicate filtering."""

    def setUp(self):
        self.ctx = MagicMock()
        self.ctx.dataset_num_procs = 1
        self.pipe = MagicMock()

    def _make_block(self, filter_column="text", threshold=0.85, group_by=None):
        """Construct a block wired to the mocked context and pipeline."""
        return SimilarityFilterBlock(
            self.ctx,
            self.pipe,
            "test_similarity_filter",
            filter_column=filter_column,
            threshold=threshold,
            group_by=group_by,
        )

    def test_keeps_unique_rows(self):
        """Dissimilar rows all survive filtering."""
        texts = ["alpha bravo charlie", "delta echo foxtrot", "golf hotel india"]
        filtered = self._make_block().generate(Dataset.from_dict({"text": texts}))
        self.assertEqual(len(filtered), 3)

    def test_removes_exact_duplicates(self):
        """Identical rows collapse down to a single survivor."""
        texts = ["hello world", "hello world", "hello world"]
        block = self._make_block(threshold=0.8)
        filtered = block.generate(Dataset.from_dict({"text": texts}))
        self.assertEqual(len(filtered), 1)

    def test_removes_near_duplicates(self):
        """Highly similar (but not identical) rows are deduplicated."""
        texts = [
            "What is photosynthesis and how does it work?",
            "What is photosynthesis and how does it function?",
            "Explain the process of sourdough bread making.",
        ]
        block = self._make_block(threshold=0.7)
        filtered = block.generate(Dataset.from_dict({"text": texts}))
        self.assertEqual(len(filtered), 2)

    def test_group_by_isolates_groups(self):
        """Identical text in different groups is not deduplicated."""
        data = {
            "text": ["same text here", "same text here"],
            "doc_id": ["doc_a", "doc_b"],
        }
        block = self._make_block(threshold=0.8, group_by="doc_id")
        filtered = block.generate(Dataset.from_dict(data))
        self.assertEqual(len(filtered), 2)

    def test_group_by_deduplicates_within_group(self):
        """Identical text inside one group collapses to a single row."""
        data = {
            "text": ["same text here", "same text here"],
            "doc_id": ["doc_a", "doc_a"],
        }
        block = self._make_block(threshold=0.8, group_by="doc_id")
        filtered = block.generate(Dataset.from_dict(data))
        self.assertEqual(len(filtered), 1)

    def test_empty_dataset(self):
        """An empty dataset passes through with zero rows."""
        filtered = self._make_block().generate(Dataset.from_dict({"text": []}))
        self.assertEqual(len(filtered), 0)

    def test_low_threshold_more_aggressive(self):
        """A lower threshold never keeps more rows than a higher one."""
        texts = [
            "What is photosynthesis?",
            "What is the process of photosynthesis?",
            "Explain sourdough bread.",
        ]
        ds = Dataset.from_dict({"text": texts})
        strict = self._make_block(threshold=0.5)
        lenient = self._make_block(threshold=0.95)
        self.assertLessEqual(len(strict.generate(ds)), len(lenient.generate(ds)))
Loading