Skip to content

Commit

Permalink
[FLINK-13513][ml] Add the Mapper and related classes for later algorithm implementations.
Browse files Browse the repository at this point in the history
  • Loading branch information
xuyang1706 authored and becketqin committed Oct 29, 2019
1 parent 4bf46c5 commit 005bda9
Show file tree
Hide file tree
Showing 7 changed files with 401 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.flink.ml.common.mapper;

import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;

import java.io.Serializable;

/**
 * Abstract class for mappers. A mapper takes one row as input and transforms it into another row.
 *
 * <p>The input schema is held as separate field-name and field-type arrays and is rebuilt into a
 * {@link TableSchema} on demand by {@link #getDataSchema()}.
 */
public abstract class Mapper implements Serializable {

	/**
	 * Explicit serial version UID so that the serialized form does not depend on the
	 * compiler-generated default, which changes whenever the class shape changes.
	 */
	private static final long serialVersionUID = 1L;

	/**
	 * Field names of the input rows.
	 */
	private final String[] dataFieldNames;

	/**
	 * Field types of the input rows, in the same order as {@link #dataFieldNames}.
	 */
	private final DataType[] dataFieldTypes;

	/**
	 * Parameters for the Mapper.
	 * Users can set the params before the Mapper is executed.
	 */
	protected final Params params;

	/**
	 * Construct a Mapper.
	 *
	 * @param dataSchema The schema of input rows.
	 * @param params The parameters for this mapper. If {@code null}, an empty {@link Params} is
	 *               used; otherwise a clone of the given instance is stored.
	 */
	public Mapper(TableSchema dataSchema, Params params) {
		this.dataFieldNames = dataSchema.getFieldNames();
		this.dataFieldTypes = dataSchema.getFieldDataTypes();
		this.params = (null == params) ? new Params() : params.clone();
	}

	/**
	 * Get the schema of input rows.
	 *
	 * @return A {@link TableSchema} built from the stored field names and field types.
	 */
	protected TableSchema getDataSchema() {
		return TableSchema.builder().fields(dataFieldNames, dataFieldTypes).build();
	}

	/**
	 * Map a row to a new row.
	 *
	 * @param row The input row.
	 * @return A new row.
	 * @throws Exception This method may throw exceptions. Throwing an exception will cause the operation to fail.
	 */
	public abstract Row map(Row row) throws Exception;

	/**
	 * Get the schema of the output rows of {@link #map(Row)} method.
	 *
	 * @return The table schema of the output rows of {@link #map(Row)} method.
	 */
	public abstract TableSchema getOutputSchema();

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.flink.ml.common.mapper;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.types.Row;

/**
 * A class that helps adapt a {@link Mapper} to a {@link MapFunction} so that the mapper can run in Flink.
 *
 * <p>Every call to {@link #map(Row)} is forwarded unchanged to the wrapped mapper.
 */
public class MapperAdapter implements MapFunction<Row, Row> {

	/**
	 * Explicit serial version UID so that the serialized form does not depend on the
	 * compiler-generated default.
	 */
	private static final long serialVersionUID = 1L;

	/**
	 * The mapper that performs the actual row transformation.
	 */
	private final Mapper mapper;

	/**
	 * Construct a MapperAdapter with the given mapper.
	 *
	 * @param mapper The {@link Mapper} to adapt.
	 */
	public MapperAdapter(Mapper mapper) {
		this.mapper = mapper;
	}

	/**
	 * Map a row by delegating to the wrapped {@link Mapper}.
	 *
	 * @param row The input row.
	 * @return The row produced by {@link Mapper#map(Row)}.
	 * @throws Exception Any exception thrown by the wrapped mapper is propagated.
	 */
	@Override
	public Row map(Row row) throws Exception {
		return this.mapper.map(row);
	}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.flink.ml.common.mapper;

import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;

import java.util.List;

/**
 * An abstract class for {@link Mapper Mappers} with a model.
 *
 * <p>In addition to the input data schema kept by {@link Mapper}, this class keeps the schema of
 * the model rows and exposes it through {@link #getModelSchema()}.
 */
public abstract class ModelMapper extends Mapper {

	/**
	 * Explicit serial version UID so that the serialized form does not depend on the
	 * compiler-generated default.
	 */
	private static final long serialVersionUID = 1L;

	/**
	 * Field names of the model rows.
	 */
	private final String[] modelFieldNames;

	/**
	 * Field types of the model rows.
	 */
	private final DataType[] modelFieldTypes;

	/**
	 * Constructs a ModelMapper.
	 *
	 * @param modelSchema The schema of the model rows passed to {@link #loadModel(List)}.
	 * @param dataSchema The schema of the input data rows.
	 * @param params The parameters of this ModelMapper.
	 */
	public ModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
		super(dataSchema, params);
		this.modelFieldNames = modelSchema.getFieldNames();
		this.modelFieldTypes = modelSchema.getFieldDataTypes();
	}

	/**
	 * Get the schema of the model rows that are passed to {@link #loadModel(List)}.
	 *
	 * @return A {@link TableSchema} built from the stored model field names and field types.
	 */
	protected TableSchema getModelSchema() {
		return TableSchema.builder().fields(this.modelFieldNames, this.modelFieldTypes).build();
	}

	/**
	 * Load the model from the list of rows.
	 *
	 * @param modelRows The list of rows containing the model.
	 */
	public abstract void loadModel(List<Row> modelRows);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.flink.ml.common.mapper;

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.common.model.ModelSource;
import org.apache.flink.types.Row;

import java.util.List;

/**
 * A class that adapts a {@link ModelMapper} to a Flink {@link RichMapFunction} so the model can be
 * loaded in a Flink job.
 *
 * <p>This adapter class holds the target {@link ModelMapper} and its {@link ModelSource}. Upon
 * {@link #open(Configuration)}, it loads the model rows from the {@link ModelSource} into the
 * {@link ModelMapper}.
 */
public class ModelMapperAdapter extends RichMapFunction<Row, Row> {

	/**
	 * Explicit serial version UID so that the serialized form does not depend on the
	 * compiler-generated default.
	 */
	private static final long serialVersionUID = 1L;

	/**
	 * The mapper that receives the model and performs the actual row transformation.
	 */
	private final ModelMapper mapper;

	/**
	 * The source the model rows are read from when the function is opened.
	 */
	private final ModelSource modelSource;

	/**
	 * Construct a ModelMapperAdapter with the given ModelMapper and ModelSource.
	 *
	 * @param mapper The {@link ModelMapper} to adapt.
	 * @param modelSource The {@link ModelSource} to load the model from.
	 */
	public ModelMapperAdapter(ModelMapper mapper, ModelSource modelSource) {
		this.mapper = mapper;
		this.modelSource = modelSource;
	}

	/**
	 * Fetch the model rows from the {@link ModelSource}, using this function's runtime context,
	 * and hand them to the {@link ModelMapper}.
	 *
	 * @param parameters The configuration passed by Flink; not read by this method.
	 * @throws Exception Declared by the overridden method; failures from the model source or the
	 *                   mapper are propagated.
	 */
	@Override
	public void open(Configuration parameters) throws Exception {
		List<Row> modelRows = this.modelSource.getModelRows(getRuntimeContext());
		this.mapper.loadModel(modelRows);
	}

	/**
	 * Map a row by delegating to the wrapped {@link ModelMapper}.
	 *
	 * @param row The input row.
	 * @return The row produced by the wrapped mapper.
	 * @throws Exception Any exception thrown by the wrapped mapper is propagated.
	 */
	@Override
	public Row map(Row row) throws Exception {
		return this.mapper.map(row);
	}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.flink.ml.common.model;

import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.types.Row;

import java.util.List;

/**
 * A {@link ModelSource} implementation that reads the model from the broadcast variable.
 */
public class BroadcastVariableModelSource implements ModelSource {

	/**
	 * Explicit serial version UID so that the serialized form does not depend on the
	 * compiler-generated default.
	 */
	private static final long serialVersionUID = 1L;

	/**
	 * The name of the broadcast variable that hosts the model.
	 */
	private final String modelVariableName;

	/**
	 * Construct a BroadcastVariableModelSource.
	 *
	 * @param modelVariableName The name of the broadcast variable that hosts the model rows.
	 */
	public BroadcastVariableModelSource(String modelVariableName) {
		this.modelVariableName = modelVariableName;
	}

	/**
	 * Get the model rows by looking up the configured broadcast variable in the runtime context.
	 *
	 * @param runtimeContext The runtime context to read the broadcast variable from.
	 * @return The content of the broadcast variable named {@code modelVariableName}.
	 */
	@Override
	public List<Row> getModelRows(RuntimeContext runtimeContext) {
		return runtimeContext.getBroadcastVariable(modelVariableName);
	}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.flink.ml.common.model;

import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.types.Row;

import java.io.Serializable;
import java.util.List;

/**
* An interface for loading the model from different sources, e.g. broadcast variables, list of rows, etc.
*
* <p>The interface extends {@link Serializable}, so implementations must be serializable.
*/
public interface ModelSource extends Serializable {

/**
* Get the rows that contain the model.
*
* @param runtimeContext The runtime context supplied by the caller, from which an implementation
* may fetch the model (for example, from a broadcast variable).
* @return The rows that contain the model.
*/
List<Row> getModelRows(RuntimeContext runtimeContext);
}
Loading

0 comments on commit 005bda9

Please sign in to comment.