Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion examples/advanced_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def main():
patience_early_stopping=5,
num_workers=0,
trainer_params={"deterministic": True},
raw_labels=False, # no encoding needed, labels are already integers
)

classifier.train(
Expand Down Expand Up @@ -162,6 +163,8 @@ def main():
patience_early_stopping=7,
num_workers=0,
trainer_params=advanced_trainer_params,
raw_labels=False, # no encoding needed, labels are already integers

)

advanced_classifier.train(
Expand Down Expand Up @@ -196,6 +199,8 @@ def main():
patience_early_stopping=3,
num_workers=0, # No multiprocessing for CPU
trainer_params={"deterministic": True, "accelerator": "cpu"},
raw_labels=False, # no encoding needed, labels are already integers

)

cpu_classifier.train(
Expand Down Expand Up @@ -225,7 +230,6 @@ def main():
"max_epochs": 25,
"enable_progress_bar": True,
"log_every_n_steps": 1,
"check_val_every_n_epoch": 2, # Validate every 2 epochs
"enable_checkpointing": True,
"enable_model_summary": True,
"deterministic": True,
Expand All @@ -238,6 +242,7 @@ def main():
patience_early_stopping=8,
num_workers=0,
trainer_params=custom_trainer_params,
raw_labels=False, # no encoding needed, labels are already integers
)

custom_classifier.train(
Expand Down
1 change: 1 addition & 0 deletions examples/basic_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def main():
lr=1e-3,
patience_early_stopping=5,
num_workers=0, # Use 0 for simple examples to avoid multiprocessing issues
raw_labels=False # no encoding needed, labels are already integers
)
classifier.train(
X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True
Expand Down
1 change: 1 addition & 0 deletions examples/multiclass_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def main():
patience_early_stopping=7,
num_workers=0,
trainer_params={"deterministic": True},
raw_labels=False, # no encoding needed, labels are already integers
)
classifier.train(
X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True
Expand Down
2 changes: 1 addition & 1 deletion examples/simple_explainability_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def main():
patience_early_stopping=5,
num_workers=0,
trainer_params={"deterministic": True},
raw_labels=False # no encoding needed, labels are already integers
)
classifier.train(
X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val, verbose=True
Expand Down Expand Up @@ -279,7 +280,6 @@ def main():
# Extract attributions and mapping info
attributions = result["attributions"][0][0] # shape: (seq_len,)
offset_mapping = result["offset_mapping"][0] # List of (start, end) tuples
word_ids = result["word_ids"][0] # List of word IDs for each token

# Map token-level attributions to character-level (for ASCII visualization)
char_attributions = map_attributions_to_char(
Expand Down
4 changes: 3 additions & 1 deletion examples/using_additional_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple
patience_early_stopping=3,
num_workers=0,
trainer_params={"enable_progress_bar": True, "deterministic": True},
raw_labels=False, # no encoding needed, labels are already integers
raw_categorical_inputs=False, # no encoding needed, categorical inputs are already integers
)

# Create and build model
Expand All @@ -172,7 +174,7 @@ def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple
if use_categorical:
print(" ✅ Running validation for text-with-categorical-variables model...")
try:
result = classifier.predict(X_test)
result = classifier.predict(X_test, raw_categorical_inputs=False)
predictions = result["prediction"].squeeze().numpy()
test_accuracy = (predictions == y_test).mean()
print(f" Test accuracy: {test_accuracy:.3f}")
Expand Down
5 changes: 3 additions & 2 deletions torchTextClassifiers/torchTextClassifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,10 @@ def train(
if y_val is not None:
assert X_val is not None, "X_val must be provided if y_val is provided."

X_val: Optional[Dict[str, Any]] = None
X_val_checked: Optional[Dict[str, Any]] = None
if X_val is not None and y_val is not None:
X_val, y_val = self._check_XY(X_val, y_val)
X_val_checked, y_val = self._check_XY(X_val, y_val, training_config.raw_categorical_inputs, training_config.raw_labels)
X_val = X_val_checked

if (
(X_train["categorical_variables"] is not None)
Expand Down