41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
"""Utilties to handle estimator list"""
|
|
|
|
from ..externals import six
|
|
from sklearn.utils.metaestimators import _BaseComposition
|
|
|
|
|
|
class _BaseXComposition(_BaseComposition):
|
|
"""
|
|
parameter handler for list of estimators
|
|
"""
|
|
def _set_params(self, attr, named_attr, **params):
|
|
# Ordered parameter replacement
|
|
# 1. root parameter
|
|
if attr in params:
|
|
setattr(self, attr, params.pop(attr))
|
|
|
|
# 2. single estimator replacement
|
|
items = getattr(self, named_attr)
|
|
names = []
|
|
if items:
|
|
names, estimators = zip(*items)
|
|
estimators = list(estimators)
|
|
for name in list(six.iterkeys(params)):
|
|
if '__' not in name and name in names:
|
|
# replace single estimator and re-build the
|
|
# root estimators list
|
|
for i, est_name in enumerate(names):
|
|
if est_name == name:
|
|
new_val = params.pop(name)
|
|
if new_val is None:
|
|
del estimators[i]
|
|
else:
|
|
estimators[i] = new_val
|
|
break
|
|
# replace the root estimators
|
|
setattr(self, attr, estimators)
|
|
|
|
# 3. estimator parameters and other initialisation arguments
|
|
super(_BaseXComposition, self).set_params(**params)
|
|
return self
|