
Commit

Merge branch 'add_cvm_plugin' into 'master'
cvm plugin draft

See merge request deep-learning/tensornet!18
gzm55 committed Dec 25, 2024
2 parents 74bfb58 + 44d4f66 commit 0e93816
Showing 15 changed files with 174 additions and 47 deletions.
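In short, the draft threads per-feature labels from the training graph down to the parameter server so that each sparse sign accumulates show and click counters, and, when cvm is enabled, every embedding carries two extra slots: the raw show count and log(click + 1) - log(show + 1). A minimal sketch of that slot pair, illustrative only and not part of the patch:

    import math

    def cvm_slots(show: float, click: float):
        # Mirrors the update in SparseAdaGradValue::Apply below: slot dim holds
        # the raw show count, slot dim + 1 holds log(click + 1) - log(show + 1).
        return show, math.log(click + 1) - math.log(show + 1)
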
20 changes: 13 additions & 7 deletions core/kernels/sparse_table_ops.cc
@@ -267,7 +267,7 @@ class SparseTablePullKernel : public AsyncOpKernel {
float* w_matrix = var_tensor->matrix<float>().data();

size_t emb_size = sizeof(float) * dim;
CHECK_EQ(emb_size, emb_buf.cutn(w_matrix + sign_index * dim, emb_size));
CHECK_EQ(emb_size, emb_buf.cutn(w_matrix + sign_index * dim , emb_size));
}
}

@@ -281,22 +281,26 @@ REGISTER_KERNEL_BUILDER(Name("SparseTablePull").Device(DEVICE_CPU),

struct SparsePushVarInfo {
public:
SparsePushVarInfo(const Tensor* t_value, const Tensor* t_grad)
SparsePushVarInfo(const Tensor* t_value, const Tensor* t_grad, const Tensor* t_labels)
: value(t_value)
, grad(t_grad) {
, grad(t_grad)
, labels(t_labels) {

const int64* feasign_vec = value->flat<int64>().data();
const int64* fea_label_vec = t_labels->flat<int64>().data();

std::map<uint64, int> sign_id_mapping;
for (int i = 0; i < value->NumElements(); ++i) {
uint64 sign = (uint64)feasign_vec[i];
int label = static_cast<int>(fea_label_vec[i]);
auto ret = sign_id_mapping.insert({sign, sign_id_mapping.size()});

if (ret.second) {
virtual_sign_infos.emplace_back(sign, 1);
virtual_sign_infos.emplace_back(sign, 1, label);
} else {
auto iter = ret.first;
virtual_sign_infos[iter->second].batch_show += 1;
virtual_sign_infos[iter->second].batch_click += label;
}
}
}
@@ -308,6 +312,7 @@ struct SparsePushVarInfo {
public:
const Tensor* value;
const Tensor* grad;
const Tensor* labels;

std::vector<SparsePushSignInfo> virtual_sign_infos;
};
@@ -321,16 +326,17 @@ class SparseTablePushKernel : public AsyncOpKernel {
}

void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
OP_REQUIRES_ASYNC(c, c->num_inputs() == N_ * 2,
OP_REQUIRES_ASYNC(c, c->num_inputs() == N_ * 3,
errors::InvalidArgument("SparseTable push num_inputs:",
c->num_inputs(),
" not equal:", N_ * 2),
" not equal:", N_ * 3),
done);
std::vector<SparsePushVarInfo> var_infos;

for (int i = 0; i < N_; i++) {
const Tensor* value = &c->input(i);
const Tensor* grad = &c->input(N_ + i);
const Tensor* labels = &c->input(2 * N_ + i);

OP_REQUIRES_ASYNC(
c, TensorShapeUtils::IsMatrix(grad->shape()),
@@ -339,7 +345,7 @@ class SparseTablePushKernel : public AsyncOpKernel {
grad->shape().DebugString()),
done);

var_infos.emplace_back(value, grad);
var_infos.emplace_back(value, grad, labels);
}

CHECK_GT(var_infos.size(), 0);
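The push kernel now takes three tensors per variable (values, grads, feature_labels) and collapses duplicate signs into per-sign counters before shipping them to the parameter server. A rough Python equivalent of that aggregation, with hypothetical names:

    def aggregate_signs(signs, labels):
        # One entry per unique sign: batch_show counts how often the sign
        # appears in the batch, batch_click sums its (0/1) labels.
        infos = {}
        for sign, label in zip(signs, labels):
            show, click = infos.get(sign, (0, 0))
            infos[sign] = (show + 1, click + int(label))
        return infos
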
4 changes: 3 additions & 1 deletion core/main/py_wrapper.cc
@@ -114,10 +114,12 @@ PYBIND11_MODULE(_pywrap_tn, m) {

return py::reinterpret_steal<py::object>(obj);
})
.def("create_sparse_table", [](py::object obj, std::string name, int dimension) {
.def("create_sparse_table", [](py::object obj, std::string name, int dimension, bool use_cvm) {
OptimizerBase* opt =
static_cast<OptimizerBase*>(PyCapsule_GetPointer(obj.ptr(), nullptr));

opt->SetUseCvm(use_cvm);

PsCluster* cluster = PsCluster::Instance();

SparseTable* table = CreateSparseTable(opt, name, dimension, cluster->RankNum(), cluster->Rank());
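With the extra flag, the Python wrapper presumably creates a cvm-enabled table along these lines (the exact call shape on the Python side is an assumption, not part of the patch):

    # Hypothetical usage of the updated binding
    # create_sparse_table(optimizer_capsule, name, dimension, use_cvm).
    handle = tn.core.create_sparse_table(sparse_opt, "user_emb", 8, True)
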
1 change: 1 addition & 0 deletions core/ops/sparse_table_ops.cc
@@ -48,6 +48,7 @@ REGISTER_OP("SparseTablePush")
)doc")
.Input("values: N * int64")
.Input("grads: N * float")
.Input("feature_labels: N * int64")
.Attr("table_handle: int")
.Attr("N: int")
.SetShapeFn(shape_inference::NoOutputs);
39 changes: 38 additions & 1 deletion core/ps/optimizer/ada_grad_kernel.cc
@@ -94,13 +94,22 @@ SparseAdaGradValue::SparseAdaGradValue(int dim, const AdaGrad* opt) {
}
}

use_cvm_ = opt->ShouldUseCvm();
g2sum_ = opt->initial_g2sum;
old_compat_ = false;
no_show_days_ = 0;
click_ = 0;
show_ = 0;
if(opt->ShouldUseCvm()){
w[dim] = 0;
w[dim+1] = 0;
}

}

void SparseAdaGradValue::Apply(const AdaGrad* opt, SparseGradInfo& grad_info, int dim) {
show_ += grad_info.batch_show;
click_ += grad_info.batch_click;
no_show_days_ = 0;

float* w = Weight();
@@ -116,6 +125,13 @@ void SparseAdaGradValue::Apply(const AdaGrad* opt, SparseGradInfo& grad_info, in
for (int i = 0; i < dim; ++i) {
w[i] -= opt->learning_rate * grad_info.grad[i] / (opt->epsilon + sqrt(g2sum_));
}
if(opt->ShouldUseCvm()){
float log_show = log(show_ + 1);
float log_click = log(click_ + 1);
w[dim] = show_;
w[dim+1] = (log_click - log_show);
}

}

void SparseAdaGradValue::SerializeTxt_(std::ostream& os, int dim) {
@@ -126,7 +142,10 @@ void SparseAdaGradValue::SerializeTxt_(std::ostream& os, int dim) {

os << g2sum_ << "\t";
os << show_ << "\t";
os << no_show_days_;
os << no_show_days_ << "\t";
if(use_cvm_){
os << click_;
}
}

void SparseAdaGradValue::DeSerializeTxt_(std::istream& is, int dim) {
@@ -139,6 +158,13 @@ void SparseAdaGradValue::DeSerializeTxt_(std::istream& is, int dim) {
is >> show_;
if(!old_compat_) {
is >> no_show_days_;
if(use_cvm_){
is >> click_;
float log_show = log(show_ + 1);
float log_click = log(click_ + 1);
Weight()[dim] = show_;
Weight()[dim+1] = (log_click - log_show);
}
}
}

@@ -147,6 +173,9 @@ void SparseAdaGradValue::SerializeBin_(std::ostream& os, int dim) {
os.write(reinterpret_cast<const char*>(&g2sum_), sizeof(g2sum_));
os.write(reinterpret_cast<const char*>(&show_), sizeof(show_));
os.write(reinterpret_cast<const char*>(&no_show_days_), sizeof(no_show_days_));
if(use_cvm_){
os.write(reinterpret_cast<const char*>(&click_), sizeof(click_));
}
}

void SparseAdaGradValue::DeSerializeBin_(std::istream& is, int dim) {
@@ -155,11 +184,19 @@ void SparseAdaGradValue::DeSerializeBin_(std::istream& is, int dim) {
is.read(reinterpret_cast<char*>(&show_), sizeof(show_));
if(!old_compat_) {
is.read(reinterpret_cast<char*>(&no_show_days_), sizeof(no_show_days_));
if(use_cvm_){
is.read(reinterpret_cast<char*>(&click_), sizeof(click_));
float log_show = log(show_ + 1);
float log_click = log(click_ + 1);
Weight()[dim] = show_;
Weight()[dim+1] = (log_click - log_show);
}
}
}

void SparseAdaGradValue::ShowDecay(const AdaGrad* opt, int delta_days) {
show_ *= opt->show_decay_rate;
click_ *= opt->show_decay_rate;
no_show_days_ += delta_days;
}

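With the optional click field, a serialized text row for a dim = 2 value would look roughly like the following tab-separated line (illustrative placeholders; layout per the column comment in core/ps/optimizer/optimizer.h below):

    <sign>  2  <w0>  <w1>  <g2sum>  <show>  <no_show_days>  <click>
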
2 changes: 2 additions & 0 deletions core/ps/optimizer/ada_grad_kernel.h
@@ -88,7 +88,9 @@ class alignas(4) SparseAdaGradValue
int dim_;
float g2sum_;
float show_ = 0.0;
float click_ = 0.0;
int no_show_days_ = 0;
bool use_cvm_ = false;
float data_[0];
};

2 changes: 2 additions & 0 deletions core/ps/optimizer/data_struct.h
@@ -20,6 +20,7 @@ namespace tensornet {
struct SparseGradInfo {
float* grad;
int batch_show;
int batch_click;
};

extern int const SERIALIZE_FMT_ID;
@@ -56,6 +57,7 @@ class alignas(4) SparseOptValue {

protected:
float show_ = 0.0;
float click_ = 0.0;
int delta_show_ = 0;
bool old_compat_ = false;
};
14 changes: 12 additions & 2 deletions core/ps/optimizer/optimizer.h
@@ -43,9 +43,18 @@ class OptimizerBase {
return std::make_tuple(false, emptyString);
}

virtual void SetUseCvm(bool use_cvm) {
use_cvm_ = use_cvm;
}

virtual bool ShouldUseCvm() const {
return use_cvm_;
}

public:
float learning_rate = 0.01;
float show_decay_rate = 0.98;
bool use_cvm_ = false;
};

class Adam : public OptimizerBase {
@@ -90,8 +99,9 @@ class AdaGrad : public OptimizerBase {
++column_count;
}

// columns should be sign, dim_, dims_ * weight, g2sum, show, no_show_days
// if columnCount is 12, means no no_show_days column
// with the cvm plugin the columns are: sign, dim_, dim_ * weight, g2sum, show, no_show_days, click (dim + 6 columns)
// without cvm there is no click column (dim + 5 columns)
// old versions also lack the no_show_days column (dim + 4 columns)
if(column_count == dim + 4){
need_old_compat = true;
}
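A small sketch of how those column counts could be mapped back to a format when loading a checkpoint (hypothetical helper, not in the patch):

    def infer_row_format(column_count: int, dim: int) -> str:
        # Mirrors the comment above: dim + 4 is the old layout without
        # no_show_days, dim + 5 is the current layout without click (cvm off),
        # and dim + 6 includes the click column (cvm on).
        if column_count == dim + 4:
            return "old_compat"
        if column_count == dim + 5:
            return "current_no_cvm"
        if column_count == dim + 6:
            return "current_cvm"
        raise ValueError(f"unexpected column count: {column_count}")
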
2 changes: 1 addition & 1 deletion core/ps/optimizer/optimizer_kernel.h
@@ -276,7 +276,7 @@ class SparseKernelBlock {
SparseKernelBlock(const OptimizerBase* opt, int dimension)
: values_(15485863, sparse_key_hasher)
, dim_(dimension)
, alloc_(ValueType::DynSizeof(dim_), 1 << 16) {
, alloc_(ValueType::DynSizeof(dimension + opt->ShouldUseCvm() * 2), 1 << 16) {
values_.max_load_factor(0.75);
opt_ = dynamic_cast<const OptType*>(opt);
mutex_ = std::make_unique<std::mutex>();
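The allocator now sizes each value for the pulled dimension plus the two cvm slots. A trivial illustration of the per-value float count (assuming nothing else in DynSizeof changes):

    def value_float_slots(dimension: int, use_cvm: bool) -> int:
        # e.g. an 8-dim embedding occupies 10 float slots when cvm is enabled
        return dimension + (2 if use_cvm else 0)
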
9 changes: 4 additions & 5 deletions core/ps/table/sparse_table.cc
@@ -47,7 +47,6 @@ void SparseTable::SetHandle(uint32_t handle) {
void SparseTable::Pull(const SparsePullRequest* req, butil::IOBuf& out_emb_buf, SparsePullResponse* resp) {
resp->set_table_handle(req->table_handle());

CHECK_EQ(dim_, req->dim());
resp->set_dim(req->dim());

for (int i = 0; i < req->signs_size(); ++i) {
@@ -57,23 +56,23 @@ void SparseTable::Pull(const SparsePullRequest* req, butil::IOBuf& out_emb_buf,
float* w = op_kernel_->GetWeight(sign);
CHECK(nullptr != w);

out_emb_buf.append(w, sizeof(float) * dim_);
out_emb_buf.append(w, sizeof(float) * (req->dim()));
}
}

void SparseTable::Push(const SparsePushRequest* req, butil::IOBuf& grad_buf, SparsePushResponse* resp) {
CHECK_EQ(dim_, req->dim());

float grad[dim_];
float grad[req->dim()];
SparsePushSignInfo sign_info;

while (sizeof(sign_info) == grad_buf.cutn(&sign_info, sizeof(sign_info))) {
size_t grad_size = sizeof(float) * dim_;
size_t grad_size = sizeof(float) * req->dim();
CHECK_EQ(grad_size, grad_buf.cutn(grad, grad_size));

SparseGradInfo grad_info;
grad_info.grad = grad;
grad_info.batch_show = sign_info.batch_show;
grad_info.batch_click = sign_info.batch_click;

op_kernel_->Apply(sign_info.sign, grad_info);
}
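The push wire format per sign is now a SparsePushSignInfo header followed by dim float32 gradients. A hedged Python sketch of that layout (assumes a 4-byte int and no struct padding, which holds for the 8+4+4 layout on common platforms):

    import struct

    SIGN_INFO_FMT = "=Qii"  # sign (uint64), batch_show, batch_click

    def parse_push_buffer(buf: bytes, dim: int):
        # Mirrors SparseTable::Push: repeated [sign_info][dim float32 grads].
        head = struct.calcsize(SIGN_INFO_FMT)
        grad_bytes = 4 * dim
        out, off = [], 0
        while off + head + grad_bytes <= len(buf):
            sign, show, click = struct.unpack_from(SIGN_INFO_FMT, buf, off)
            grads = struct.unpack_from("=%df" % dim, buf, off + head)
            out.append((sign, show, click, list(grads)))
            off += head + grad_bytes
        return out
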
6 changes: 4 additions & 2 deletions core/ps_interface/ps_raw_interface.h
@@ -22,16 +22,18 @@ namespace tensornet {
struct SparsePushSignInfo {
public:
SparsePushSignInfo()
: SparsePushSignInfo(0, 0)
: SparsePushSignInfo(0, 0, 0)
{ }

SparsePushSignInfo(uint64_t s, int bs)
SparsePushSignInfo(uint64_t s, int bs, int cs)
: sign(s)
, batch_show(bs)
, batch_click(cs)
{ }

uint64_t sign;
int batch_show;
int batch_click;
};

} // namespace tensornet
5 changes: 4 additions & 1 deletion examples/common/feature_column.py
@@ -24,11 +24,14 @@ def create_emb_model(features, columns_group, suffix = "_input"):
for slot in features:
inputs[slot] = tf.keras.layers.Input(name=slot, shape=(None,), dtype="int64", sparse=True)

inputs["label"] = tf.keras.layers.Input(name="label", shape=(None,), dtype="int64", sparse=False)

sparse_opt = tn.core.AdaGrad(learning_rate=0.01, initial_g2sum=0.1, initial_scale=0.1)

for column_group_name in columns_group.keys():
embs = tn.layers.EmbeddingFeatures(columns_group[column_group_name], sparse_opt,
name=column_group_name + suffix)(inputs)
name=column_group_name + suffix, target_columns=["label"])(inputs)
#name=column_group_name + suffix)(inputs)
model_output.append(embs)

emb_model = tn.model.Model(inputs=inputs, outputs=model_output, name="emb_model")
2 changes: 1 addition & 1 deletion examples/main.py
@@ -22,7 +22,7 @@ def parse_line_batch(example_proto):
fea_desc[slot] = tf.io.VarLenFeature(tf.int64)

feature_dict = tf.io.parse_example(example_proto, fea_desc)
label = feature_dict.pop('label')
label = feature_dict['label']
return feature_dict, label

def create_model():
16 changes: 9 additions & 7 deletions examples/models/wide_deep.py
@@ -25,13 +25,14 @@ def create_sub_model(linear_embs, deep_embs, deep_hidden_units):
for i, unit in enumerate(C.DEEP_HIDDEN_UNITS):
deep = tf.keras.layers.Dense(unit, activation='relu', name='dnn_{}'.format(i))(deep)

if linear_inputs and not deep_inputs:
output = tf.keras.layers.Dense(1, activation='sigmoid', name='pred')(linear)
elif deep_inputs and not linear_inputs:
output = tf.keras.layers.Dense(1, activation='sigmoid', name='pred')(deep)
else:
both = tf.keras.layers.concatenate([deep, linear], name='both')
output = tf.keras.layers.Dense(1, activation='sigmoid', name='pred')(both)
# if linear_inputs and not deep_inputs:
# output = tf.keras.layers.Dense(1, activation='sigmoid', name='pred')(linear)
# elif deep_inputs and not linear_inputs:
# output = tf.keras.layers.Dense(1, activation='sigmoid', name='pred')(deep)
# else:
both = tf.keras.layers.concatenate([deep, linear], name='both')
both = tn.layers.TNBatchNormalization(synchronized=True, sync_freq=4, max_count=1000000)(both)
output = tf.keras.layers.Dense(1, activation='sigmoid', name='pred')(both)

return tn.model.Model(inputs=[linear_inputs, deep_inputs], outputs=output, name="sub_model")

@@ -45,6 +46,7 @@ def WideDeep(linear_features, dnn_features, dnn_hidden_units=(128, 128)):
inputs = {}
for slot in features:
inputs[slot] = tf.keras.layers.Input(name=slot, shape=(None,), dtype="int64", sparse=True)
inputs['label'] = tf.keras.layers.Input(name="label", shape=(None,), dtype="int64", sparse=False)
emb_model = create_emb_model(features, columns_group)
linear_embs, deep_embs = emb_model(inputs)
sub_model = create_sub_model(linear_embs, deep_embs, dnn_hidden_units)