forked from apache/flink
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FLINK-13513][ml] Add the Mapper and related classes for later algori…
…thm implementations.
- Loading branch information
1 parent
4bf46c5
commit 005bda9
Showing
7 changed files
with
401 additions
and
0 deletions.
There are no files selected for viewing
83 changes: 83 additions & 0 deletions
83
flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/Mapper.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.apache.flink.ml.common.mapper; | ||
|
||
import org.apache.flink.ml.api.misc.param.Params; | ||
import org.apache.flink.table.api.TableSchema; | ||
import org.apache.flink.table.types.DataType; | ||
import org.apache.flink.types.Row; | ||
|
||
import java.io.Serializable; | ||
|
||
/** | ||
* Abstract class for mappers. A mapper takes one row as input and transform it into another row. | ||
*/ | ||
public abstract class Mapper implements Serializable { | ||
|
||
/** | ||
* Schema of the input rows. | ||
*/ | ||
private final String[] dataFieldNames; | ||
private final DataType[] dataFieldTypes; | ||
|
||
/** | ||
* Parameters for the Mapper. | ||
* Users can set the params before the Mapper is executed. | ||
*/ | ||
protected final Params params; | ||
|
||
/** | ||
* Construct a Mapper. | ||
* | ||
* @param dataSchema The schema of input rows. | ||
* @param params The parameters for this mapper. | ||
*/ | ||
public Mapper(TableSchema dataSchema, Params params) { | ||
this.dataFieldNames = dataSchema.getFieldNames(); | ||
this.dataFieldTypes = dataSchema.getFieldDataTypes(); | ||
this.params = (null == params) ? new Params() : params.clone(); | ||
} | ||
|
||
/** | ||
* Get the schema of input rows. | ||
* | ||
* @return The schema of input rows. | ||
*/ | ||
protected TableSchema getDataSchema() { | ||
return TableSchema.builder().fields(dataFieldNames, dataFieldTypes).build(); | ||
} | ||
|
||
/** | ||
* Map a row to a new row. | ||
* | ||
* @param row The input row. | ||
* @return A new row. | ||
* @throws Exception This method may throw exceptions. Throwing an exception will cause the operation to fail. | ||
*/ | ||
public abstract Row map(Row row) throws Exception; | ||
|
||
/** | ||
* Get the schema of the output rows of {@link #map(Row)} method. | ||
* | ||
* @return The table schema of the output rows of {@link #map(Row)} method. | ||
*/ | ||
public abstract TableSchema getOutputSchema(); | ||
|
||
} |
45 changes: 45 additions & 0 deletions
45
...ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/MapperAdapter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.apache.flink.ml.common.mapper; | ||
|
||
import org.apache.flink.api.common.functions.MapFunction; | ||
import org.apache.flink.types.Row; | ||
|
||
/** | ||
* A class that helps adapt a {@link Mapper} to a {@link MapFunction} so that the mapper can run in Flink. | ||
*/ | ||
public class MapperAdapter implements MapFunction<Row, Row> { | ||
|
||
private final Mapper mapper; | ||
|
||
/** | ||
* Construct a MapperAdapter with the given mapper. | ||
* | ||
* @param mapper The {@link Mapper} to adapt. | ||
*/ | ||
public MapperAdapter(Mapper mapper) { | ||
this.mapper = mapper; | ||
} | ||
|
||
@Override | ||
public Row map(Row row) throws Exception { | ||
return this.mapper.map(row); | ||
} | ||
} |
72 changes: 72 additions & 0 deletions
72
...k-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/ModelMapper.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.apache.flink.ml.common.mapper; | ||
|
||
import org.apache.flink.ml.api.misc.param.Params; | ||
import org.apache.flink.table.api.TableSchema; | ||
import org.apache.flink.table.types.DataType; | ||
import org.apache.flink.types.Row; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* An abstract class for {@link Mapper Mappers} with a model. | ||
*/ | ||
public abstract class ModelMapper extends Mapper { | ||
|
||
/** | ||
* Field names of the model rows. | ||
*/ | ||
private final String[] modelFieldNames; | ||
|
||
/** | ||
* Field types of the model rows. | ||
*/ | ||
private final DataType[] modelFieldTypes; | ||
|
||
/** | ||
* Constructs a ModelMapper. | ||
* | ||
* @param modelSchema The schema of the model rows passed to {@link #loadModel(List)}. | ||
* @param dataSchema The schema of the input data rows. | ||
* @param params The parameters of this ModelMapper. | ||
*/ | ||
public ModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { | ||
super(dataSchema, params); | ||
this.modelFieldNames = modelSchema.getFieldNames(); | ||
this.modelFieldTypes = modelSchema.getFieldDataTypes(); | ||
} | ||
|
||
/** | ||
* Get the schema of the model rows that are passed to {@link #loadModel(List)}. | ||
* | ||
* @return The schema of the model rows. | ||
*/ | ||
protected TableSchema getModelSchema() { | ||
return TableSchema.builder().fields(this.modelFieldNames, this.modelFieldTypes).build(); | ||
} | ||
|
||
/** | ||
* Load the model from the list of rows. | ||
* | ||
* @param modelRows The list of rows that containing the model. | ||
*/ | ||
public abstract void loadModel(List<Row> modelRows); | ||
} |
62 changes: 62 additions & 0 deletions
62
...rent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/ModelMapperAdapter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.apache.flink.ml.common.mapper; | ||
|
||
import org.apache.flink.api.common.functions.RichMapFunction; | ||
import org.apache.flink.configuration.Configuration; | ||
import org.apache.flink.ml.common.model.ModelSource; | ||
import org.apache.flink.types.Row; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* A class that adapts a {@link ModelMapper} to a Flink {@link RichMapFunction} so the model can be | ||
* loaded in a Flink job. | ||
* | ||
* <p>This adapter class hold the target {@link ModelMapper} and it's {@link ModelSource}. Upon open(), | ||
* it will load model rows from {@link ModelSource} into {@link ModelMapper}. | ||
*/ | ||
public class ModelMapperAdapter extends RichMapFunction<Row, Row> { | ||
|
||
private final ModelMapper mapper; | ||
private final ModelSource modelSource; | ||
|
||
/** | ||
* Construct a ModelMapperAdapter with the given ModelMapper and ModelSource. | ||
* | ||
* @param mapper The {@link ModelMapper} to adapt. | ||
* @param modelSource The {@link ModelSource} to load the model from. | ||
*/ | ||
public ModelMapperAdapter(ModelMapper mapper, ModelSource modelSource) { | ||
this.mapper = mapper; | ||
this.modelSource = modelSource; | ||
} | ||
|
||
@Override | ||
public void open(Configuration parameters) throws Exception { | ||
List<Row> modelRows = this.modelSource.getModelRows(getRuntimeContext()); | ||
this.mapper.loadModel(modelRows); | ||
} | ||
|
||
@Override | ||
public Row map(Row row) throws Exception { | ||
return this.mapper.map(row); | ||
} | ||
} |
50 changes: 50 additions & 0 deletions
50
...k-ml-lib/src/main/java/org/apache/flink/ml/common/model/BroadcastVariableModelSource.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.apache.flink.ml.common.model; | ||
|
||
import org.apache.flink.api.common.functions.RuntimeContext; | ||
import org.apache.flink.types.Row; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* A {@link ModelSource} implementation that reads the model from the broadcast variable. | ||
*/ | ||
public class BroadcastVariableModelSource implements ModelSource { | ||
|
||
/** | ||
* The name of the broadcast variable that hosts the model. | ||
*/ | ||
private final String modelVariableName; | ||
|
||
/** | ||
* Construct a BroadcastVariableModelSource. | ||
* | ||
* @param modelVariableName The name of the broadcast variable that hosts a BroadcastVariableModelSource. | ||
*/ | ||
public BroadcastVariableModelSource(String modelVariableName) { | ||
this.modelVariableName = modelVariableName; | ||
} | ||
|
||
@Override | ||
public List<Row> getModelRows(RuntimeContext runtimeContext) { | ||
return runtimeContext.getBroadcastVariable(modelVariableName); | ||
} | ||
} |
39 changes: 39 additions & 0 deletions
39
flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/model/ModelSource.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.apache.flink.ml.common.model; | ||
|
||
import org.apache.flink.api.common.functions.RuntimeContext; | ||
import org.apache.flink.types.Row; | ||
|
||
import java.io.Serializable; | ||
import java.util.List; | ||
|
||
/** | ||
* An interface that load the model from different sources. E.g. broadcast variables, list of rows, etc. | ||
*/ | ||
public interface ModelSource extends Serializable { | ||
|
||
/** | ||
* Get the rows that containing the model. | ||
* | ||
* @return the rows that containing the model. | ||
*/ | ||
List<Row> getModelRows(RuntimeContext runtimeContext); | ||
} |
Oops, something went wrong.